1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_fusedbatchnorm_forward.h"
12 #include "libxsmm_main.h"
13
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <math.h>
18 #if defined(LIBXSMM_OFFLOAD_TARGET)
19 # pragma offload_attribute(pop)
20 #endif
21
22
23 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
24 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
25 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
26 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
27 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
28 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid);
29
30
/**
 * AVX512 forward fused batch-norm for F32 input/output with F32 statistics and a
 * feature-map block of 16 (the dispatcher calls it when handle->ofmblock == 16).
 *
 * The computation lives in the included template, which is shared with the BF16 variant.
 * This function only instantiates that template once per supported fuse_ops combination,
 * toggling the LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros around each #include.
 * NOTE(review): the template presumably refers to handle, start_thread, tid, status and the
 * element_*_type typedefs by name - confirm before renaming any of them.
 *
 * @return LIBXSMM_DNN_SUCCESS (unless the template sets status otherwise),
 *         LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION,
 *         or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without AVX512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;

  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* The plain BN variants must match fuse_ops exactly. The remaining cases test bit masks,
       so the combined eltwise+ReLU variants are tried before the single-op ones
       (presumably the combined flags are supersets of the single flags - TODO confirm). */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  /* the dispatcher only calls this kernel inside #if defined(LIBXSMM_INTRINSICS_AVX512) */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
80
81
/**
 * AVX512 forward fused batch-norm for F32 input/output with F32 statistics and a
 * feature-map block of 32 (the dispatcher calls it when handle->ofmblock == 32).
 *
 * The computation lives in the included template, which is shared with the BF16 variant.
 * This function only instantiates that template once per supported fuse_ops combination,
 * toggling the LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros around each #include.
 * NOTE(review): the template presumably refers to handle, start_thread, tid, status and the
 * element_*_type typedefs by name - confirm before renaming any of them.
 *
 * @return LIBXSMM_DNN_SUCCESS (unless the template sets status otherwise),
 *         LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION,
 *         or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without AVX512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;

  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* The plain BN variants must match fuse_ops exactly. The remaining cases test bit masks,
       so the combined eltwise+ReLU variants are tried before the single-op ones
       (presumably the combined flags are supersets of the single flags - TODO confirm). */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  /* the dispatcher only calls this kernel inside #if defined(LIBXSMM_INTRINSICS_AVX512) */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
131
132
/**
 * AVX512 forward fused batch-norm for F32 input/output with F32 statistics and a
 * feature-map block of 64 (the dispatcher calls it when handle->ofmblock == 64).
 *
 * The computation lives in the included template, which is shared with the BF16 variant.
 * This function only instantiates that template once per supported fuse_ops combination,
 * toggling the LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros around each #include.
 * NOTE(review): the template presumably refers to handle, start_thread, tid, status and the
 * element_*_type typedefs by name - confirm before renaming any of them.
 *
 * @return LIBXSMM_DNN_SUCCESS (unless the template sets status otherwise),
 *         LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION,
 *         or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without AVX512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_stats_type;

  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* The plain BN variants must match fuse_ops exactly. The remaining cases test bit masks,
       so the combined eltwise+ReLU variants are tried before the single-op ones
       (presumably the combined flags are supersets of the single flags - TODO confirm). */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
#else /* should not happen */
  /* the dispatcher only calls this kernel inside #if defined(LIBXSMM_INTRINSICS_AVX512) */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
182
183
/**
 * AVX512 forward fused batch-norm for BF16 input/output with F32 statistics and a
 * feature-map block of 16 (the dispatcher calls it when handle->ofmblock == 16).
 *
 * The computation lives in the included template, which is shared with the F32 variant.
 * LIBXSMM_DNN_FUSEDBN_FWD_BF16 is defined for the whole function, presumably to select
 * the template's BF16 load/store path (TODO confirm). The template is instantiated once per
 * supported fuse_ops combination, toggling the LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros.
 * NOTE(review): the template presumably refers to handle, start_thread, tid, status and the
 * element_*_type typedefs by name - confirm before renaming any of them.
 *
 * @return LIBXSMM_DNN_SUCCESS (unless the template sets status otherwise),
 *         LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION,
 *         or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without AVX512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;

  /* BF16 mode for every template instantiation below; undefined again at the end */
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* The plain BN variants must match fuse_ops exactly. The remaining cases test bit masks,
       so the combined eltwise+ReLU variants are tried before the single-op ones
       (presumably the combined flags are supersets of the single flags - TODO confirm). */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c16_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
  /* the dispatcher only calls this kernel inside #if defined(LIBXSMM_INTRINSICS_AVX512) */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
235
236
/**
 * AVX512 forward fused batch-norm for BF16 input/output with F32 statistics and a
 * feature-map block of 32 (the dispatcher calls it when handle->ofmblock == 32).
 *
 * The computation lives in the included template, which is shared with the F32 variant.
 * LIBXSMM_DNN_FUSEDBN_FWD_BF16 is defined for the whole function, presumably to select
 * the template's BF16 load/store path (TODO confirm). The template is instantiated once per
 * supported fuse_ops combination, toggling the LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros.
 * NOTE(review): the template presumably refers to handle, start_thread, tid, status and the
 * element_*_type typedefs by name - confirm before renaming any of them.
 *
 * @return LIBXSMM_DNN_SUCCESS (unless the template sets status otherwise),
 *         LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION,
 *         or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without AVX512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;

  /* BF16 mode for every template instantiation below; undefined again at the end */
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* The plain BN variants must match fuse_ops exactly. The remaining cases test bit masks,
       so the combined eltwise+ReLU variants are tried before the single-op ones
       (presumably the combined flags are supersets of the single flags - TODO confirm). */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c32_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
  /* the dispatcher only calls this kernel inside #if defined(LIBXSMM_INTRINSICS_AVX512) */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
288
289
/**
 * AVX512 forward fused batch-norm for BF16 input/output with F32 statistics and a
 * feature-map block of 64 (the dispatcher calls it when handle->ofmblock == 64).
 *
 * The computation lives in the included template, which is shared with the F32 variant.
 * LIBXSMM_DNN_FUSEDBN_FWD_BF16 is defined for the whole function, presumably to select
 * the template's BF16 load/store path (TODO confirm). The template is instantiated once per
 * supported fuse_ops combination, toggling the LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_* macros.
 * NOTE(review): the template presumably refers to handle, start_thread, tid, status and the
 * element_*_type typedefs by name - confirm before renaming any of them.
 *
 * @return LIBXSMM_DNN_SUCCESS (unless the template sets status otherwise),
 *         LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION,
 *         or LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH when compiled without AVX512 intrinsics.
 */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef float element_stats_type;

  /* BF16 mode for every template instantiation below; undefined again at the end */
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
  /* only the BN -> eltwise -> ReLU fusion order is implemented */
  if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
    status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
  } else {
    /* The plain BN variants must match fuse_ops exactly. The remaining cases test bit masks,
       so the combined eltwise+ReLU variants are tried before the single-op ones
       (presumably the combined flags are supersets of the single flags - TODO confirm). */
    if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
         (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
    } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_bf16_c64_avx512.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
    } else {
      status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
    }
  }
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
#else /* should not happen */
  /* the dispatcher only calls this kernel inside #if defined(LIBXSMM_INTRINSICS_AVX512) */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
341
342
/**
 * Forward fused batch-norm entry point for the custom (blocked) tensor format.
 *
 * First checks that every tensor required by desc.fuse_ops is bound. It then dispatches:
 *  - when compiled with AVX512 intrinsics, running on an AVX512 target (libxsmm_target_archid)
 *    and handle->ofmblock is 16, 32 or 64, to the matching
 *    libxsmm_dnn_fusedbatchnorm_st_fwd_custom_{f32_f32,bf16_bf16}_c{16,32,64} kernel;
 *  - otherwise, to the generic template, instantiated inline below for F32 or BF16.
 *
 * @param handle       fused batch-norm handle (descriptor, bound tensors, scratch)
 * @param start_thread id of the first participating thread (forwarded to the kernels)
 * @param tid          id of the calling thread (forwarded to the kernels)
 * @return LIBXSMM_DNN_SUCCESS or one of LIBXSMM_DNN_ERR_DATA_NOT_BOUND,
 *         LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE (anything but F32/F32 or BF16/BF16 in/out),
 *         LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER, LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION,
 *         or whatever status a kernel/template produces.
 * NOTE(review): the generic template presumably refers to handle, start_thread, tid, status
 * and the element_*_type typedefs by name - confirm before renaming any of them.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_custom(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if all required tensors are bound */
  if ( handle->reg_input == 0 || handle->reg_output == 0 ||
       handle->reg_beta == 0 || handle->reg_gamma == 0 ||
       handle->expvalue == 0 || handle->rcpstddev == 0 || handle->variance == 0 ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* the BN op additionally needs the scratch buffer */
  if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_BN) > 0 ) {
    if ( handle->scratch == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }
  /* element-wise fusion needs the bound reg_add tensor */
  if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) > 0 ) {
    if ( handle->reg_add == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }
  /* ReLU-with-mask needs the bound relumask tensor */
  if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) > 0 ) {
    if ( handle->relumask == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }

  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
       (handle->ofmblock == 16) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c16( handle, start_thread, tid );
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c16( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
              (handle->ofmblock == 32) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c32( handle, start_thread, tid );
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c32( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else if ( ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) &&
              (handle->ofmblock == 64) ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_f32_f32_c64( handle, start_thread, tid );
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      status = libxsmm_dnn_fusedbatchnorm_st_fwd_custom_bf16_bf16_c64( handle, start_thread, tid );
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  /* generic fallback: non-AVX512 build/target or an ofmblock without a specialized kernel */
  {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_stats_type;

      /* only the BN -> eltwise -> ReLU fusion order is implemented */
      if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
        status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
      } else {
        /* plain BN variants match exactly; combined eltwise+ReLU masks are tried before single ops */
        if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
             (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
        } else {
          status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
        }
      }
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 ) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef libxsmm_bfloat16 element_output_type;
      typedef float element_stats_type;

      /* BF16 mode for the generic template instantiations below; undefined again at the end */
# define LIBXSMM_DNN_FUSEDBN_FWD_BF16
      if ( handle->desc.fuse_order != LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU ) {
        status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER;
      } else {
        if ( (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BN) ||
             (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS) || (handle->desc.fuse_ops == LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED) ) {
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE) == LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_ELTWISE
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU) == LIBXSMM_DNN_FUSEDBN_OPS_RELU ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU
        } else if ( (handle->desc.fuse_ops & LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK) == LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK ) {
# define LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
# include "template/libxsmm_dnn_fusedbatchnorm_st_fwd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FUSEDBN_FWD_ENABLE_RELU_WITH_MASK
        } else {
          status = LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION;
        }
      }
# undef LIBXSMM_DNN_FUSEDBN_FWD_BF16
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
496
497
libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm * handle,int start_thread,int tid)498 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_st_fwd_nhwc(libxsmm_dnn_fusedbatchnorm* handle, int start_thread, int tid)
499 {
500 libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
501 LIBXSMM_UNUSED( handle );
502 LIBXSMM_UNUSED( start_thread );
503 LIBXSMM_UNUSED( tid );
504 return status;
505 }
506
507
libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm ** handles,int num_handles,int start_thread,int tid)508 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom(libxsmm_dnn_fusedbatchnorm** handles, int num_handles, int start_thread, int tid)
509 {
510 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
511 int l_count;
512
513 /* check if all required tensors are bound */
514 for ( l_count = 0; l_count < num_handles; ++l_count ) {
515 if ( handles[l_count]->expvalue == 0 || handles[l_count]->rcpstddev == 0 || handles[l_count]->variance == 0 || handles[l_count]->scratch == 0 ) {
516 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
517 return status;
518 }
519 }
520
521 #if 0
522 /* check if we are on an AVX512 platform */
523 if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
524 status = libxsmm_dnn_fusedbatchnorm_reduce_stats_st_fwd_custom_avx512( handles, num_handles, start_thread, tid );
525 } else
526 #endif
527 {
528 const int nImg = handles[0]->desc.partN;
529 const int nBlocksFm = handles[0]->blocksifm;
530 const int nFmBlock = handles[0]->ifmblock;
531 /* computing first logical thread */
532 const int ltid = tid - start_thread;
533 /* number of tasks that could be run in parallel */
534 const int work2 = nBlocksFm;
535 /* compute chunk size */
536 const int chunksize2 = (work2 % handles[0]->desc.threads == 0) ? (work2 / handles[0]->desc.threads) : ((work2 / handles[0]->desc.threads) + 1);
537 /* compute thr_begin and thr_end */
538 const int thr_begin2 = (ltid * chunksize2 < work2) ? (ltid * chunksize2) : work2;
539 const int thr_end2 = ((ltid + 1) * chunksize2 < work2) ? ((ltid + 1) * chunksize2) : work2;
540 int v = 0, fm;
541 const float sqrt_eps = 1e-7f;
542 const float nhw = (float)(handles[0]->desc.fullN * handles[0]->desc.H * handles[0]->desc.W);
543 const float recp_nhw = 1.0f/nhw;
544
545 LIBXSMM_VLA_DECL(2, float, bmean0, (float*)handles[0]->expvalue->data, nFmBlock);
546 LIBXSMM_VLA_DECL(2, float, brstd0, (float*)handles[0]->rcpstddev->data, nFmBlock);
547 LIBXSMM_VLA_DECL(2, float, variance0, (float*)handles[0]->variance->data, nFmBlock);
548 LIBXSMM_VLA_DECL(3, float, sum_img0, (float*)handles[0]->scratch, nImg, nFmBlock);
549 LIBXSMM_VLA_DECL(3, float, sumsq_img0, ((float*)handles[0]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
550
551 /* lazy barrier init */
552 libxsmm_barrier_init(handles[0]->barrier, ltid);
553
554 /* now we need to reduce the sum and sum^2, we use the final */
555 for ( l_count = 1; l_count < num_handles; ++l_count ) {
556 LIBXSMM_VLA_DECL(3, float, sum_imgr, (float*)handles[l_count]->scratch, nImg, nFmBlock);
557 LIBXSMM_VLA_DECL(3, float, sumsq_imgr, ((float*)handles[l_count]->scratch) + ((size_t)nImg * (size_t)nBlocksFm * (size_t)nFmBlock), nImg, nFmBlock);
558
559 for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
560 float* sum_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img0, fm, 0, 0, nImg, nFmBlock);
561 float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock);
562 float* sum_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, sum_imgr, fm, 0, 0, nImg, nFmBlock);
563 float* sumsq_imgr_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_imgr, fm, 0, 0, nImg, nFmBlock);
564
565 LIBXSMM_PRAGMA_SIMD
566 for ( v=0; v < nFmBlock; v++ ) {
567 sum_img0_ptr[v] += sum_imgr_ptr[v];
568 sumsq_img0_ptr[v] += sumsq_imgr_ptr[v];
569 }
570 }
571 }
572
573 for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
574 float* bmean0_ptr = &LIBXSMM_VLA_ACCESS(2, bmean0, fm, 0, nFmBlock);
575 float* brstd0_ptr = &LIBXSMM_VLA_ACCESS(2, brstd0, fm, 0, nFmBlock);
576 float* tvar0_ptr = &LIBXSMM_VLA_ACCESS(2, variance0, fm, 0, nFmBlock);
577 float* sum_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sum_img0, fm, 0, 0, nImg, nFmBlock);
578 float* sumsq_img0_ptr = &LIBXSMM_VLA_ACCESS(3, sumsq_img0, fm, 0, 0, nImg, nFmBlock);
579
580 LIBXSMM_PRAGMA_SIMD
581 for ( v=0; v < nFmBlock; v++ ) {
582 const float tbmean = (recp_nhw * sum_img0_ptr[v]);
583 const float tbmeansq = tbmean * tbmean;
584 const float tsqbmean = recp_nhw * sumsq_img0_ptr[v];
585 const float tvar = tsqbmean - tbmeansq;
586 const float tbrstd = (float)(1.0/sqrt((double)tvar + sqrt_eps));
587 bmean0_ptr[v] = tbmean;
588 brstd0_ptr[v] = tbrstd;
589 tvar0_ptr[v] = tvar;
590 }
591 }
592
593 for ( l_count = 1; l_count < num_handles; ++l_count ) {
594 LIBXSMM_VLA_DECL(2, float, bmeanr, (float*)handles[l_count]->expvalue->data, nFmBlock);
595 LIBXSMM_VLA_DECL(2, float, brstdr, (float*)handles[l_count]->rcpstddev->data, nFmBlock);
596 LIBXSMM_VLA_DECL(2, float, variancer, (float*)handles[l_count]->variance->data, nFmBlock);
597
598 for ( fm = thr_begin2; fm < thr_end2; ++fm ) {
599 float* bmean0_ptr = &LIBXSMM_VLA_ACCESS(2, bmean0, fm, 0, nFmBlock);
600 float* brstd0_ptr = &LIBXSMM_VLA_ACCESS(2, brstd0, fm, 0, nFmBlock);
601 float* tvar0_ptr = &LIBXSMM_VLA_ACCESS(2, variance0, fm, 0, nFmBlock);
602 float* bmeanr_ptr = &LIBXSMM_VLA_ACCESS(2, bmeanr, fm, 0, nFmBlock);
603 float* brstdr_ptr = &LIBXSMM_VLA_ACCESS(2, brstdr, fm, 0, nFmBlock);
604 float* tvarr_ptr = &LIBXSMM_VLA_ACCESS(2, variancer, fm, 0, nFmBlock);
605
606 LIBXSMM_PRAGMA_SIMD
607 for ( v=0; v < nFmBlock; v++ ) {
608 bmeanr_ptr[v] = bmean0_ptr[v];
609 brstdr_ptr[v] = brstd0_ptr[v];
610 tvarr_ptr[v] = tvar0_ptr[v];
611 }
612 }
613 }
614
615 libxsmm_barrier_wait(handles[0]->barrier, ltid);
616 }
617
618 return status;
619 }
620
621