1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_fusedbatchnorm_forward.h"
12 #include "libxsmm_main.h"
13 
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <math.h>
18 #if defined(LIBXSMM_OFFLOAD_TARGET)
19 # pragma offload_attribute(pop)
20 #endif
21 
22 
/* AVX-512 specialisations of the forward fused batch-norm (custom layout), one
 * per (input/output datatype, channel block size) pair. They are defined below
 * and selected at run time by libxsmm_dnn_fusedbatchnorm_st_fwd_custom. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
29 
30 
/*
 * AVX-512 forward fused batch-norm kernel: f32 input/output, f32 statistics,
 * channel block of 16. libxsmm_dnn_fusedbatchnorm_st_fwd_custom calls it for
 * f32/f32 data when ofmblock == 16 on an AVX-512 target.
 *
 * The computation is textually #included from
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c. The
 * LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros defined around each #include
 * presumably select which fused epilogue (element-wise add, ReLU, ReLU with
 * mask) that instance compiles in.
 * NOTE(review): the template presumably reads handle, start_thread, tid and the
 * element_*_type typedefs from this scope -- keep those names unchanged.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER for any
 * fuse_order other than BN_ELTWISE_RELU, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION
 * when no branch matches fuse_ops, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built
 * without AVX-512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;

  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* plain batch-norm modes (exact match): no fused epilogue */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
    /* The combined ELTWISE_RELU* masks are tested before the single ELTWISE /
     * RELU masks, so the most-fused matching variant wins (assuming the
     * combined masks are unions of the single ones). */
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      /* no supported combination of fuse_ops flags matched */
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
80 
81 
/*
 * AVX-512 forward fused batch-norm kernel: f32 input/output, f32 statistics,
 * channel block of 32. libxsmm_dnn_fusedbatchnorm_st_fwd_custom calls it for
 * f32/f32 data when ofmblock == 32 on an AVX-512 target.
 *
 * The computation is textually #included from
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c. The
 * LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros defined around each #include
 * presumably select which fused epilogue (element-wise add, ReLU, ReLU with
 * mask) that instance compiles in.
 * NOTE(review): the template presumably reads handle, start_thread, tid and the
 * element_*_type typedefs from this scope -- keep those names unchanged.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER for any
 * fuse_order other than BN_ELTWISE_RELU, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION
 * when no branch matches fuse_ops, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built
 * without AVX-512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;

  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* plain batch-norm modes (exact match): no fused epilogue */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
    /* The combined ELTWISE_RELU* masks are tested before the single ELTWISE /
     * RELU masks, so the most-fused matching variant wins (assuming the
     * combined masks are unions of the single ones). */
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      /* no supported combination of fuse_ops flags matched */
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
131 
132 
/*
 * AVX-512 forward fused batch-norm kernel: f32 input/output, f32 statistics,
 * channel block of 64. libxsmm_dnn_fusedbatchnorm_st_fwd_custom calls it for
 * f32/f32 data when ofmblock == 64 on an AVX-512 target.
 *
 * The computation is textually #included from
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c. The
 * LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros defined around each #include
 * presumably select which fused epilogue (element-wise add, ReLU, ReLU with
 * mask) that instance compiles in.
 * NOTE(review): the template presumably reads handle, start_thread, tid and the
 * element_*_type typedefs from this scope -- keep those names unchanged.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER for any
 * fuse_order other than BN_ELTWISE_RELU, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION
 * when no branch matches fuse_ops, or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built
 * without AVX-512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;

  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* plain batch-norm modes (exact match): no fused epilogue */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
    /* The combined ELTWISE_RELU* masks are tested before the single ELTWISE /
     * RELU masks, so the most-fused matching variant wins (assuming the
     * combined masks are unions of the single ones). */
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      /* no supported combination of fuse_ops flags matched */
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
182 
183 
/*
 * AVX-512 forward fused batch-norm kernel: bf16 (libxsmm_bfloat16)
 * input/output, f32 statistics, channel block of 16.
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom calls it for bf16/bf16 data when
 * ofmblock == 16 on an AVX-512 target.
 *
 * It shares the f32/bf16 template
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c with the
 * f32 kernel. LIBXSMM_DNN_FUSEDBN_FWD_BF16 is defined for the whole dispatch,
 * presumably so the template emits its bf16 load/store path. The
 * LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros around each #include presumably
 * select the fused epilogue.
 * NOTE(review): the template presumably reads handle, start_thread, tid and the
 * element_*_type typedefs from this scope -- keep those names unchanged.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER,
 * LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;

  /* switch every template instance below to the bf16 variant */
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* plain batch-norm modes (exact match): no fused epilogue */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
    /* The combined ELTWISE_RELU* masks are tested before the single ELTWISE /
     * RELU masks, so the most-fused matching variant wins (assuming the
     * combined masks are unions of the single ones). */
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      /* no supported combination of fuse_ops flags matched */
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
235 
236 
/*
 * AVX-512 forward fused batch-norm kernel: bf16 (libxsmm_bfloat16)
 * input/output, f32 statistics, channel block of 32.
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom calls it for bf16/bf16 data when
 * ofmblock == 32 on an AVX-512 target.
 *
 * It shares the f32/bf16 template
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c with the
 * f32 kernel. LIBXSMM_DNN_FUSEDBN_FWD_BF16 is defined for the whole dispatch,
 * presumably so the template emits its bf16 load/store path. The
 * LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros around each #include presumably
 * select the fused epilogue.
 * NOTE(review): the template presumably reads handle, start_thread, tid and the
 * element_*_type typedefs from this scope -- keep those names unchanged.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER,
 * LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;

  /* switch every template instance below to the bf16 variant */
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* plain batch-norm modes (exact match): no fused epilogue */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
    /* The combined ELTWISE_RELU* masks are tested before the single ELTWISE /
     * RELU masks, so the most-fused matching variant wins (assuming the
     * combined masks are unions of the single ones). */
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      /* no supported combination of fuse_ops flags matched */
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
288 
289 
/*
 * AVX-512 forward fused batch-norm kernel: bf16 (libxsmm_bfloat16)
 * input/output, f32 statistics, channel block of 64.
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom calls it for bf16/bf16 data when
 * ofmblock == 64 on an AVX-512 target.
 *
 * It shares the f32/bf16 template
 * libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c with the
 * f32 kernel. LIBXSMM_DNN_FUSEDBN_FWD_BF16 is defined for the whole dispatch,
 * presumably so the template emits its bf16 load/store path. The
 * LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros around each #include presumably
 * select the fused epilogue.
 * NOTE(review): the template presumably reads handle, start_thread, tid and the
 * element_*_type typedefs from this scope -- keep those names unchanged.
 *
 * Returns LIBXSMM_DNN_SUCCESS, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER,
 * LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION, or
 * LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when built without AVX-512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;

  /* switch every template instance below to the bf16 variant */
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* plain batch-norm modes (exact match): no fused epilogue */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
    /* The combined ELTWISE_RELU* masks are tested before the single ELTWISE /
     * RELU masks, so the most-fused matching variant wins (assuming the
     * combined masks are unions of the single ones). */
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      /* no supported combination of fuse_ops flags matched */
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
341 
342 
libxsmm_dnn_fusedbatchnorm_st_fwd_custom(libxsmm_dnn_fusedbatchnorm * handle,int start_thread,int tid)343 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
344 {
345   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
346 
347   /* check if all required tensors are bound */
348   if ( handle->reg_input == 0 || handle->reg_output == 0 ||
349        handle->reg_beta == 0  || handle->reg_gamma == 0  ||
350        handle->expvalue == 0  || handle->rcpstddev == 0  || handle->variance == 0 ) {
351     status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
352     return status;
353   }
354   if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0 ) {
355     if ( handle->scratch == 0 ) {
356       status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
357       return status;
358     }
359   }
360   if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) > 0 ) {
361     if ( handle->reg_add == 0 ) {
362       status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
363       return status;
364     }
365   }
366   if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) > 0 ) {
367     if ( handle->relumask == 0 ) {
368       status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
369       return status;
370     }
371   }
372 
373   /* check if we are on an AVX512 platform */
374 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
375   if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
376        (handle->ofmblock == 16) ) {
377     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
378       status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16( handle, start_thread, tid );
379     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
380       status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid );
381     } else {
382       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
383       return status;
384     }
385   } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
386        (handle->ofmblock == 32) ) {
387     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
388       status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32( handle, start_thread, tid );
389     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
390       status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid );
391     } else {
392       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
393       return status;
394     }
395   } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
396        (handle->ofmblock == 64) ) {
397     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
398       status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64( handle, start_thread, tid );
399     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
400       status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid );
401     } else {
402       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
403       return status;
404     }
405   } else
406 #endif
407   {
408     if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
409       typedef float element_input_type;
410       typedef float element_output_type;
411       typedef float element_stats_type;
412 
413       if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
414         status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
415       } else {
416         if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
417              (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
418 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
419         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
420 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
421 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
422 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
423 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
424 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
425         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
426 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
427 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
428 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
429 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
430 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
431         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
432 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
433 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
434 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
435         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
436 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
437 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
438 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
439         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
440 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
441 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
442 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
443         } else {
444           status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
445         }
446       }
447     } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
448       typedef libxsmm_bfloat16 element_input_type;
449       typedef libxsmm_bfloat16 element_output_type;
450       typedef float element_stats_type;
451 
452 # define LIBXSMM_DNN_FUSEDBN_FWD_BF16
453       if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
454         status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
455       } else {
456         if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN)            ||
457              (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED)    ) {
458 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
459         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
460 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
461 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
462 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
463 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
464 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
465         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
466 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
467 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
468 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
469 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
470 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
471         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
472 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
473 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
474 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
475         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
476 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
477 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
478 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
479         } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
480 # define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
481 # include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
482 # undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
483         } else {
484           status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
485         }
486       }
487 # undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
488     } else {
489       status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
490       return status;
491     }
492   }
493 
494   return status;
495 }
496 
497 
libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm * handle,int start_thread,int tid)498 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
499 {
500   libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
501   LIBXSMM_UNUSED( handle );
502   LIBXSMM_UNUSED( start_thread );
503   LIBXSMM_UNUSED( tid );
504   return status;
505 }
506 
507 
libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm ** handles,int num_handles,int start_thread,int tid)508 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid)
509 {
510   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
511   int l_count;
512 
513   /* check if all required tensors are bound */
514   for ( l_count = 0; l_count < num_handles; ++l_count ) {
515     if ( handles[l_count]->expvalue == 0  || handles[l_count]->rcpstddev == 0  || handles[l_count]->variance == 0 || handles[l_count]->scratch == 0 ) {
516       status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
517       return status;
518     }
519   }
520 
521 #if 0
522   /* check if we are on an AVX512 platform */
523   if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
524     status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom_avx512( handles, num_handles, start_thread, tid );
525   } else
526 #endif
527   {
528     const int nImg = handles[0]->desc.partN;
529     const int nBlocksFm = handles[0]->blocksifm;
530     const int nFmBlock = handles[0]->ifmblock;
531     /* computing first logical thread */
532     const int ltid = tid - start_thread;
533     /* number of tasks that could be run in parallel */
534     const int work2 = nBlocksFm;
535     /* compute chunk size */
536     const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1);
537     /* compute thr_begin and thr_end */
538     const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2;
539     const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2;
540     int v = 0, fm;
541     const float sqrt_eps = 1e-7f;
542     const float nhw = (float)(handles[0]->desc.fullN * handles[0]->desc.H * handles[0]->desc.W);
543     const float recp_nhw = 1.0f/nhw;
544 
545     LIBXSMM_VLA_DECL(2, float, bmean0,     (float*)handles[0]->expvalue->data,    nFmBlock);
546     LIBXSMM_VLA_DECL(2, float, brstd0,     (float*)handles[0]->rcpstddev->data,   nFmBlock);
547     LIBXSMM_VLA_DECL(2, float, variance0,  (float*)handles[0]->variance->data,    nFmBlock);
548     LIBXSMM_VLA_DECL(3, float, sum_img0,   (float*)handles[0]->scratch,                                                           nImg, nFmBlock);
549     LIBXSMM_VLA_DECL(3, float, sumsq_img0, ((float*)handles[0]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
550 
551     /* lazy barrier init */
552     libxsmm_barrier_init(handles[0]->barrier, ltid);
553 
554     /* now we need to reduce the sum and sum^2, we use the final  */
555     for ( l_count = 1; l_count < num_handles; ++l_count ) {
556       LIBXSMM_VLA_DECL(3, float, sum_imgr,   (float*)handles[l_count]->scratch,                                                           nImg, nFmBlock);
557       LIBXSMM_VLA_DECL(3, float, sumsq_imgr, ((float*)handles[l_count]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
558 
559       for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
560         float* sum_img0_ptr   = &LIBXSMM_VLA_ACCESS(3, sum_img0,   fm, 0, 0, nImg, nFmBlock);
561         float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock);
562         float* sum_imgr_ptr   = &LIBXSMM_VLA_ACCESS(3, sum_imgr,   fm, 0, 0, nImg, nFmBlock);
563         float* sumsq_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_imgr, fm, 0, 0, nImg, nFmBlock);
564 
565         LIBXSMM_PRAGMA_SIMD
566         for ( v=0; v < nFmBlock; v++ ) {
567           sum_img0_ptr[v] += sum_imgr_ptr[v];
568           sumsq_img0_ptr[v] += sumsq_imgr_ptr[v];
569         }
570       }
571     }
572 
573     for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
574       float* bmean0_ptr      = &LIBXSMM_VLA_ACCESS(2, bmean0,     fm, 0, nFmBlock);
575       float* brstd0_ptr      = &LIBXSMM_VLA_ACCESS(2, brstd0,     fm, 0, nFmBlock);
576       float* tvar0_ptr       = &LIBXSMM_VLA_ACCESS(2, variance0,  fm, 0, nFmBlock);
577       float* sum_img0_ptr   = &LIBXSMM_VLA_ACCESS(3, sum_img0,   fm, 0, 0, nImg, nFmBlock);
578       float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock);
579 
580       LIBXSMM_PRAGMA_SIMD
581       for ( v=0; v < nFmBlock; v++ ) {
582         const float tbmean = (recp_nhw * sum_img0_ptr[v]);
583         const float tbmeansq = tbmean * tbmean;
584         const float tsqbmean = recp_nhw * sumsq_img0_ptr[v];
585         const float tvar     = tsqbmean - tbmeansq;
586         const float tbrstd = (float)(1.0/sqrt((double)tvar + sqrt_eps));
587         bmean0_ptr[v] = tbmean;
588         brstd0_ptr[v] = tbrstd;
589         tvar0_ptr[v] = tvar;
590       }
591     }
592 
593     for ( l_count = 1; l_count < num_handles; ++l_count ) {
594       LIBXSMM_VLA_DECL(2, float, bmeanr,     (float*)handles[l_count]->expvalue->data,    nFmBlock);
595       LIBXSMM_VLA_DECL(2, float, brstdr,     (float*)handles[l_count]->rcpstddev->data,   nFmBlock);
596       LIBXSMM_VLA_DECL(2, float, variancer,  (float*)handles[l_count]->variance->data,    nFmBlock);
597 
598       for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
599         float* bmean0_ptr      = &LIBXSMM_VLA_ACCESS(2, bmean0,     fm, 0, nFmBlock);
600         float* brstd0_ptr      = &LIBXSMM_VLA_ACCESS(2, brstd0,     fm, 0, nFmBlock);
601         float* tvar0_ptr       = &LIBXSMM_VLA_ACCESS(2, variance0,  fm, 0, nFmBlock);
602         float* bmeanr_ptr      = &LIBXSMM_VLA_ACCESS(2, bmeanr,     fm, 0, nFmBlock);
603         float* brstdr_ptr      = &LIBXSMM_VLA_ACCESS(2, brstdr,     fm, 0, nFmBlock);
604         float* tvarr_ptr       = &LIBXSMM_VLA_ACCESS(2, variancer,  fm, 0, nFmBlock);
605 
606         LIBXSMM_PRAGMA_SIMD
607         for ( v=0; v < nFmBlock; v++ ) {
608           bmeanr_ptr[v] = bmean0_ptr[v];
609           brstdr_ptr[v] = brstd0_ptr[v];
610           tvarr_ptr[v] = tvar0_ptr[v];
611         }
612       }
613     }
614 
615     libxsmm_barrier_wait(handles[0]->barrier, ltid);
616   }
617 
618   return status;
619 }
620 
621