1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Hans Pabst, Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm_dnn.h>
12 #include "libxsmm_main.h"
13
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <math.h>
18 #if defined(_OPENMP)
19 # include <omp.h>
20 #endif
21 #if defined(LIBXSMM_OFFLOAD_TARGET)
22 # pragma offload_attribute(pop)
23 #endif
24
25
libxsmm_dnn_init(int target_arch)26 LIBXSMM_API_INTERN void libxsmm_dnn_init(int target_arch)
27 {
28 LIBXSMM_UNUSED(target_arch);
29 }
30
31
libxsmm_dnn_finalize(void)32 LIBXSMM_API_INTERN void libxsmm_dnn_finalize(void)
33 {
34 }
35
36
libxsmm_dnn_get_feature_map_blocks(int C,int K,int * C_block,int * K_block,int * fm_lp_block,libxsmm_dnn_datatype datatype_in,libxsmm_dnn_datatype datatype_out)37 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_get_feature_map_blocks( int C, int K, int* C_block, int* K_block, int* fm_lp_block, libxsmm_dnn_datatype datatype_in, libxsmm_dnn_datatype datatype_out ) {
38 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
39 int ifmblock = 0;
40 int ofmblock = 0;
41 int lp_block = 0;
42 int tmp_max_c_block = 32;
43 int tmp_max_k_block = 32;
44 int tmp_block = 0;
45
46 /* init libxsmm */
47 LIBXSMM_INIT
48
49 /* C */
50 if (libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE) {
51 tmp_max_c_block = 64;
52 }
53 if ( C < tmp_max_c_block ) {
54 ifmblock = C;
55 } else {
56 for ( tmp_block = 1; tmp_block <= tmp_max_c_block; tmp_block *= 2 ) {
57 if ( C % tmp_block == 0 ) ifmblock = tmp_block;
58 }
59 }
60
61 /* K */
62 if (libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE) {
63 tmp_max_k_block = 64;
64 }
65 if ( K < tmp_max_k_block ) {
66 ofmblock = K;
67 } else {
68 for ( tmp_block = 1; tmp_block <= tmp_max_k_block; tmp_block *= 2 ) {
69 if ( K % tmp_block == 0 ) ofmblock = tmp_block;
70 }
71 }
72
73 /* when do we need VNNI format? */
74 if ( (datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
75 lp_block = 1;
76 } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
77 lp_block = 2;
78 } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_I16) && ((datatype_out == LIBXSMM_DNN_DATATYPE_I32) || (datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
79 lp_block = 2;
80 } else if (datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
81 lp_block = 4;
82 } else {
83 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
84 return status;
85 }
86
87 *C_block = ifmblock;
88 *K_block = ofmblock;
89 *fm_lp_block = lp_block;
90
91 return status;
92 }
93
94
/* Maps a libxsmm_dnn_err_t status code to a human-readable message.
 * Returns a pointer to a static string; the caller must not free it.
 * Unknown codes fall through to a generic message. */
LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code)
{
  switch (code) {
    case LIBXSMM_DNN_SUCCESS:
      return "LIBXSMM DNN Success!";
    case LIBXSMM_DNN_WARN_FALLBACK:
      return "LIBXSMM DNN Warning: Falling back to naive code as target is currently not supported by LIBXSMM!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING:
      return "LIBXSMM DNN Warning: RNN cell suboptimal minibatch blocking!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING:
      return "LIBXSMM DNN Warning: RNN cell suboptimal input feature blocking!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING:
      return "LIBXSMM DNN Warning: RNN cell suboptimal output feature blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING:
      return "LIBXSMM DNN Warning: FC layer suboptimal minibatch blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING:
      return "LIBXSMM DNN Warning: FC layer suboptimal input feature blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING:
      return "LIBXSMM DNN Warning: FC layer suboptimal output feature blocking!";
    case LIBXSMM_DNN_ERR_GENERAL:
      return "LIBXSMM DNN Error: General error occurred!";
    case LIBXSMM_DNN_ERR_CREATE_HANDLE:
      return "LIBXSMM DNN Error: Handle creation failed!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE:
      return "LIBXSMM DNN Error: Requested datatype is not available!";
    case LIBXSMM_DNN_ERR_INVALID_BLOCKING:
      return "LIBXSMM DNN Error: Requested Input/Output buffer size cannot be blocked!";
    case LIBXSMM_DNN_ERR_INVALID_HANDLE:
      return "LIBXSMM DNN Error: An invalid handle was provided!";
    case LIBXSMM_DNN_ERR_DATA_NOT_BOUND:
      return "LIBXSMM DNN Error: Not all required sources and destinations have been bound to convolution!";
    case LIBXSMM_DNN_ERR_CREATE_TENSOR:
      return "LIBXSMM DNN Error: Tensor creation failed!";
    case LIBXSMM_DNN_ERR_INVALID_TENSOR:
      return "LIBXSMM DNN Error: Invalid tensor was specified!";
    case LIBXSMM_DNN_ERR_MISMATCH_TENSOR:
      return "LIBXSMM DNN Error: Tensor doesn't match handle it should be bind to!";
    case LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR:
      return "LIBXSMM DNN Error: Invalid handle or tensor!";
    case LIBXSMM_DNN_ERR_INVALID_KIND:
      return "LIBXSMM DNN Error: Invalid convolution kind!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW:
      return "LIBXSMM DNN Error: NCHW format is currently not natively supported by LIBXSMM!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT:
      return "LIBXSMM DNN Error: Unsupported destination format when copying data!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT:
      return "LIBXSMM DNN Error: Unsupported source format when copying data!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE:
      return "LIBXSMM DNN Error: Unsupported format when requesting a convolution!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS:
      return "LIBXSMM DNN Error: KCRS format is currently not natively supported by LIBXSMM!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL:
      return "LIBXSMM DNN Error: Invalid format was specified!";
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT:
      return "LIBXSMM DNN Error: Layout creation failed!";
    case LIBXSMM_DNN_ERR_INVALID_LAYOUT:
      return "LIBXSMM DNN Error: Invalid layout was specified!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH:
      return "LIBXSMM DNN Error: Unsupported architecture!";
    case LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED:
      return "LIBXSMM DNN Error: scratch binding failed as scratch was not allocated!";
    case LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE:
      return "LIBXSMM DNN Error: an unknown tensor type was provided!";
    case LIBXSMM_DNN_ERR_INVALID_ALGO:
      return "LIBXSMM DNN Error: Invalid algorithm was specified!";
    case LIBXSMM_DNN_ERR_INVALID_PADDING:
      return "LIBXSMM DNN Error: Invalid padding was specified!";
    case LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL:
      return "LIBXSMM DNN Error: time steps should be >= 2 for RNN/LSTM!";
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS:
      return "LIBXSMM DNN Error: failed to create internal layout arrays!";
    case LIBXSMM_DNN_ERR_NOT_IMPLEMENTED:
      return "LIBXSMM DNN Error: the requested functionality is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER:
      return "LIBXSMM DNN Error: the requested order of fusion in batch norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in batch norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN:
      return "LIBXSMM DNN Error: Unsupported format when requesting a fused batch norm!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING:
      return "LIBXSMM DNN Error: Unsupported pooling operations was requested!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FC:
      return "LIBXSMM DNN Error: Unsupported format when requesting a fullyconnected layer!";
    case LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN:
      return "LIBXSMM DNN Error: max sequence length is shorter than sequence length we attempt to set!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER:
      return "LIBXSMM DNN Error: the requested order of fusion in group norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in group norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in fullyconnected is right now not implemented!";
    default:
      return "LIBXSMM DNN Error: Unknown error or warning occurred!";
  }
}
190
191
libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype)192 LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype)
193 {
194 switch (datatype) {
195 case LIBXSMM_DNN_DATATYPE_F32: return 4;
196 case LIBXSMM_DNN_DATATYPE_I32: return 4;
197 case LIBXSMM_DNN_DATATYPE_BF16:return 2;
198 case LIBXSMM_DNN_DATATYPE_I16: return 2;
199 case LIBXSMM_DNN_DATATYPE_I8: return 1;
200 /* no error expected as enumeration really arrives at an enum; compiler-checked */
201 default: return 1;
202 }
203 }
204
205
libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype)206 LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype)
207 {
208 size_t l_cl_width_bytes;
209
210 /* init libxsmm */
211 LIBXSMM_INIT
212
213 if ( libxsmm_target_archid == LIBXSMM_X86_GENERIC ) {
214 l_cl_width_bytes = libxsmm_dnn_typesize(datatype);
215 } else if ( libxsmm_target_archid == LIBXSMM_X86_SSE3 ||
216 libxsmm_target_archid == LIBXSMM_X86_SSE4 ) {
217 l_cl_width_bytes = 16;
218 } else if ( libxsmm_target_archid == LIBXSMM_X86_AVX2 ||
219 libxsmm_target_archid == LIBXSMM_X86_AVX ) {
220 l_cl_width_bytes = 32;
221 } else {
222 l_cl_width_bytes = 64;
223 }
224
225 return l_cl_width_bytes/libxsmm_dnn_typesize(datatype);
226 }
227
228
/* Returns the maximum absolute value in in_buffer[0..length-1].
 * NOTE(review): reads in_buffer[0] unconditionally — assumes length >= 1; confirm callers. */
LIBXSMM_API_INLINE float libxsmm_internal_get_max( float* in_buffer, int length ) {
  float absmax_value = LIBXSMM_ABS(in_buffer[0]);
  int i = 0;
#ifdef _OPENMP
  LIBXSMM_OMP_VAR(i);
# pragma omp parallel private(i)
  {
    /* per-thread running maximum, seeded with the global initial value */
    float my_absmax_value = absmax_value;
#   pragma omp for
    for (i = 0; i < length; ++i ) {
      if (LIBXSMM_ABS(in_buffer[i]) > my_absmax_value) {
        my_absmax_value = LIBXSMM_ABS(in_buffer[i]);
      }
    }
    /* manual max-reduction: merge the per-thread result under a critical section */
#   pragma omp critical
    {
      if (my_absmax_value > absmax_value) {
        absmax_value = my_absmax_value;
      }
    }
  }
#else
  /* serial path: element 0 already seeded absmax_value, start at 1 */
  for (i = 1; i < length; ++i ) {
    if (LIBXSMM_ABS(in_buffer[i]) > absmax_value) {
      absmax_value = LIBXSMM_ABS(in_buffer[i]);
    }
  }
#endif

  return absmax_value;
}
260
261
libxsmm_internal_get_max_exp(float * in_buffer,int length)262 LIBXSMM_API_INLINE unsigned char libxsmm_internal_get_max_exp( float* in_buffer, int length ) {
263 libxsmm_intfloat val_exp;
264 unsigned char max_exp = 0;
265
266 /* bit-wise conversion to int */
267 val_exp.f = libxsmm_internal_get_max( in_buffer, length );
268 /* shift by mantissa to the right and convert to char */
269 max_exp = (unsigned char)((val_exp.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32);
270
271 return max_exp;
272 }
273
274
libxsmm_internal_quantize_scalar_no_scf(float input,unsigned char max_exp,unsigned char add_shift,int round_mode)275 LIBXSMM_API_INLINE short libxsmm_internal_quantize_scalar_no_scf( float input, unsigned char max_exp, unsigned char add_shift, int round_mode ) {
276 libxsmm_intfloat value;
277 unsigned int qvalue = 0;
278 unsigned int mant = 0;
279 unsigned int sign = 0;
280 unsigned char rhs = 0;
281 unsigned char exp_off = 0;
282
283 /* init libxsmm */
284 LIBXSMM_INIT
285
286 /* in case of zero we don't need to do anything */
287 if (LIBXSMM_FEQ(input, 0)) {
288 qvalue = 0;
289 } else {
290 /* let's get a float copy to work on */
291 /* vinp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( in_buffer[i] ); */
292 value.f = input;
293 /* let's compute the offset of the current exp at pos i from max offset, we need to mask the sign bit though */
294 /*__m512i vexp = _mm512_cvtps_epi32(_mm512_getexp_ps (vinp));
295 __m512i vexp_off = _mm512_sub_epi32(maxexpf, vexp);*/
296 exp_off = (unsigned char)(max_exp - ((value.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32));
297 /* cut out mantissa and set leading bit */
298 /*__m512i mmask = _mm512_set1_epi32(LIBXSMM_DNN_MASK_MANT_F32);
299 __m512i vmant = _mm512_or_epi32(_mm512_set1_epi32(0x1 << LIBXSMM_DNN_MANT_SZ_F32), _mm512_and_epi32( _mm512_castps_si512( vinp ), mmask));*/
300 mant = ((0x1 << LIBXSMM_DNN_MANT_SZ_F32) | (value.ui & LIBXSMM_DNN_MASK_MANT_F32));
301 /* extract sign */
302 /* __mmask16 smask = _mm512_cmplt_ps_mask (inp, _mm512_set1_ps(0)); */
303 sign = ((value.ui & LIBXSNN_DNN_MASK_SIGN_F32) >> (LIBXSMM_DNN_SZ_F32-1));
304 /* calculate rhs, be aware of the now explicit leading bit, @TODO add DFP8/4 */
305 rhs = (unsigned char)((LIBXSMM_DNN_MANT_SZ_F32+1) - LIBXSMM_DNN_MANT_DFP16 + exp_off + add_shift);
306 /* some safety, to generate 0 when we fall off quant region, @TODO issue a LIBXSMM Warning that we shifted out the entire mantissa */
307 if (rhs > (LIBXSMM_DNN_MANT_SZ_F32+1)) {
308 rhs = (LIBXSMM_DNN_MANT_SZ_F32+1);
309 }
310 /* finally shift the value into the region we need, this is now a 15-add_rhs bit number for the max value in in_buffer */
311 qvalue = (mant >> rhs);
312 /* handle sign, 2 complement */
313 if ( (sign > 0) && (qvalue > 0) ) {
314 qvalue = (~qvalue + 1);
315 }
316
317 if (round_mode == LIBXSMM_DNN_QUANT_BIAS_ROUND) {
318 /* biased rounding towards next bigger number */
319 /* first let's determine in the original number if we need a bias rounding, @TODO need fix for F64 */
320 int bias_needed = (mant & (0x3 << (rhs-2)));
321 /* apply bias */
322 if (bias_needed > 0) {
323 qvalue++;
324 }
325 } else if (round_mode == LIBXSMM_DNN_QUANT_NEAREST_ROUND) {
326 int nearest_needed = (mant & (0x1 << (rhs-1)));
327 /* apply rounding */
328 if ((nearest_needed > 0) && (rhs > 1)) {
329 qvalue++;
330 }
331 } else if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) {
332 /* stochastic rounding, as implemented in the IBM paper from 2015, @TODO, fix F64 and DFP8 */
333 const float eps = LIXSMMM_DNN_RES_DFP16;
334 /* coverity[dont_call] */
335 const float r = (float)rand();
336 libxsmm_intfloat fvalue;
337 float p, q;
338 /* masking all bits which will be shifted out */
339 fvalue.ui = value.ui & ((LIBXSMM_DNN_MASK_FULL_F32) << rhs);
340 /* drawing a random number */
341 p = r/((float)RAND_MAX);
342 q = (input - fvalue.f)/eps;
343 /* apply rounding if needed */
344 if ((p + q) > 0.5f) {
345 ++qvalue;
346 }
347 } else {
348 /* do nothing about rounding, just chop */
349 }
350 }
351
352 return (short)qvalue;
353 }
354
355
356 /* @TODO make this routine aware of any int type */
/* @TODO make this routine aware of any int type */
/* Quantizes a flat FP32 buffer to 16-bit integers and reports the scale factor
 * (power-of-two exponent) via *scf. LIBXSMM_DNN_QUANT_FPHW_ROUND multiplies by
 * a float scale and rounds in FP hardware; all other modes go through
 * libxsmm_internal_quantize_scalar_no_scf. */
LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  int i = 0;

  /* init libxsmm */
  LIBXSMM_INIT

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    const float max_value = libxsmm_internal_get_max( in_buffer, length );
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    /* shift exponent so the max value fills the (15 - add_shift) available bits */
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    /* vectorized path: 16 floats per iteration, non-temporal 256-bit stores */
    if ( length % 16 == 0 ) {
      __m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
#     pragma omp parallel for private(i)
#endif
      for (i = 0; i < length; i+=16 ) {
        _mm256_stream_si256( (__m256i *)&(out_buffer[i]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i]), vscfq ) );
      }
    } else {
#endif
      /* scalar fallback: scale and round each element individually */
#ifdef _OPENMP
#     pragma omp parallel for private(i)
#endif
      for (i = 0; i < length; ++i ) {
        out_buffer[i] = (short)LIBXSMM_ROUNDF(in_buffer[i] * scfq);
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, length );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i)
#endif
    for (i = 0; i < length; ++i ) {
      out_buffer[i] = libxsmm_internal_quantize_scalar_no_scf( in_buffer[i], max_exp, add_shift, round_mode );
    }

    /* scale factor from the biased max exponent (127 = FP32 bias) */
    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
419
420
/* Quantizes an activation tensor from blocked FP32 layout
 * [N][C/cblk_f32][H][W][cblk_f32] into blocked/VNNI int16 layout
 * [N][C/(cblk_i16*lp_blk)][H][W][cblk_i16][lp_blk], writing the scale factor
 * exponent to *scf. The fi* indices remap each i16 element back to its FP32
 * source position across the different channel blockings. */
LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  LIBXSMM_VLA_DECL(5, const float, in, in_buffer, C/cblk_f32, H, W, cblk_f32);
  LIBXSMM_VLA_DECL(6, short, out, out_buffer, C/(cblk_i16*lp_blk), H, W, cblk_i16, lp_blk);
  const unsigned int cblk = C/(cblk_i16*lp_blk);
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6;

  /* init libxsmm */
  LIBXSMM_INIT

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  assert((C % cblk_i16) == 0);

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    const float max_value = libxsmm_internal_get_max( in_buffer, N*C*H*W );
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    /* fast path: both layouts use a 16-channel inner block, so the copy is flat */
    if ( (cblk_f32 == 16) && (cblk_i16*lp_blk == 16) ) {
      __m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1);
#     pragma omp parallel for private(i1)
#endif
      for (i1 = 0; i1 < (int)(N*C*H*W); i1 += 16 ) {
        _mm256_stream_si256( (__m256i *)&(out_buffer[i1]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i1]), vscfq ) );
      }
    } else {
#endif
      /* generic path: walk the i16 layout and gather from the FP32 layout */
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)N; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)H; ++i3 ) {
            for (i4 = 0; i4 < (int)W; ++i4 ) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
                for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
                  /* linear channel index split into FP32 block/offset coordinates */
                  const int fi1 = i1;
                  const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
                  LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                  LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32) * scfq);
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant act\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, N*C*H*W );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)N; ++i1 ) {
      for (i2 = 0; i2 < (int)cblk; ++i2 ) {
        for (i3 = 0; i3 < (int)H; ++i3 ) {
          for (i4 = 0; i4 < (int)W; ++i4 ) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
              for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
                /* same index remapping as the FPHW path above */
                const int fi1 = i1;
                const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
                const int fi3 = i3;
                const int fi4 = i4;
                const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
                LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
                LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32), max_exp, add_shift, round_mode);
              }
            }
          }
        }
      }
    }

    /* scale factor from the biased max exponent (127 = FP32 bias) */
    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
524
525
/* Quantizes a convolution filter tensor from blocked FP32 layout
 * [K/kblk_f32][C/cblk_f32][R][S][cblk_f32][kblk_f32] into blocked/VNNI int16
 * layout [K/kblk_i16][C/(cblk_i16*lp_blk)][R][S][cblk_i16][kblk_i16][lp_blk],
 * reporting the scale-factor exponent via *scf. The fi* indices remap each
 * i16 element back to its FP32 source position across the two blockings. */
LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  LIBXSMM_VLA_DECL(6, const float, in, in_buffer, C/cblk_f32, R, S, cblk_f32, kblk_f32);
  LIBXSMM_VLA_DECL(7, short, out, out_buffer, C/(cblk_i16*lp_blk), R, S, cblk_i16, kblk_i16, lp_blk);
  unsigned int cblk = C/(cblk_i16*lp_blk);
  unsigned int kblk = K/kblk_i16;
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6, i7;

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  assert((C % (cblk_i16*lp_blk)) == 0);
  assert((K % kblk_f32) == 0);
  assert((K % kblk_i16) == 0);
  assert((lp_blk % 2) == 0);

  /* init libxsmm */
  LIBXSMM_INIT

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    const float max_value = libxsmm_internal_get_max( in_buffer, K*C*R*S );
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    /* fast path for 16x16 blocks: quantize two adjacent input channels,
     * interleave them pairwise (VNNI), and stream out one 512-bit store */
    if ( (kblk_f32 == 16) && (cblk_f32 == 16) && (kblk_i16 == 16) && (cblk_i16*lp_blk == 16) ) {
      const __m512 vscfq = _mm512_set1_ps(scfq);
      const __m512i permute_compact_idx = _mm512_set_epi32(15,14,13,12,7,6,5,4,11,10,9,8,3,2,1,0);
#ifdef _OPENMP
#     pragma omp parallel for private(i1, i2, i3, i4, i5) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)R; ++i3 ) {
            for (i4 = 0; i4 < (int)S; ++i4 ) {
              for (i5 = 0; i5 < 16; i5+=2 ) {
                /* quantize the even/odd input-channel rows (16 K-values each) */
                __m256i even_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 0, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
                __m256i odd_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 1, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
                /* pairwise interleave into VNNI order, then fix lane order */
                __m256i compressed_lo = _mm256_unpacklo_epi16(even_ch, odd_ch);
                __m256i compressed_hi = _mm256_unpackhi_epi16(even_ch, odd_ch);
                __m512i compact = _mm512_inserti64x4( _mm512_setzero_si512(), compressed_lo, 0);
                compact = _mm512_inserti64x4(compact, compressed_hi, 1);
                compact = _mm512_permutexvar_epi32(permute_compact_idx, compact);
                LIBXSMM_INTRINSICS_MM512_STREAM_SI512(
                  (void*)&LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5 / 2, 0, 0, cblk, R, S, cblk_i16, kblk_i16, lp_blk),
                  compact);
              }
            }
          }
        }
      }
    } else {
#endif
      /* generic path: walk the i16 layout and gather from the FP32 layout */
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6); LIBXSMM_OMP_VAR(i7);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)R; ++i3 ) {
            for (i4 = 0; i4 < (int)S; ++i4 ) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
                for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
                  for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
                    /* split linear K and C indices into FP32 block/offset coordinates */
                    const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32;
                    const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32;
                    const int fi3 = i3;
                    const int fi4 = i4;
                    const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32;
                    const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32;
                    LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                    LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32) * scfq);
                  }
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, K*C*R*S );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)kblk; ++i1 ) {
      for (i2 = 0; i2 < (int)cblk; ++i2 ) {
        for (i3 = 0; i3 < (int)R; ++i3 ) {
          for (i4 = 0; i4 < (int)S; ++i4 ) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
              for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
                for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
                  /* same index remapping as the FPHW path above */
                  const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32;
                  const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32;
                  const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32;
                  LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
                  LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32), max_exp, add_shift, round_mode);
                }
              }
            }
          }
        }
      }
    }

    /* scale factor from the biased max exponent (127 = FP32 bias) */
    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
658
659
libxsmm_dnn_dequantize(short * in_buffer,float * out_buffer,int length,unsigned char scf)660 LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf ) {
661 const float val_exp = libxsmm_sexp2_i8i(-scf);
662 int i = 0;
663
664 #ifdef _OPENMP
665 # pragma omp parallel for private(i)
666 #endif
667 for ( i = 0; i < length; ++i ) {
668 out_buffer[i] = ((float)in_buffer[i])*val_exp;
669 }
670 }
671
672
libxsmm_truncate_convert_f32_bf16(const float * in,libxsmm_bfloat16 * out,unsigned int length)673 LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length) {
674 unsigned int i = 0;
675
676 /* truncate buffer to bf16 */
677 for ( i = 0; i < length; ++i ) {
678 libxsmm_bfloat16_hp t;
679
680 t.f = in[i];
681 out[i] = t.i[1];
682 }
683 }
684
685
libxsmm_rnaz_convert_fp32_bf16(const float * in,libxsmm_bfloat16 * out,unsigned int len)686 LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
687 unsigned int i = 0;
688
689 /* truncate buffer to bf16 */
690 for ( i = 0; i < len; ++i ) {
691 unsigned int int_round = 0;
692 unsigned int do_round = 1;
693
694 int_round = *((unsigned int*)&(in[i]));
695
696 /* we don't round NaN and inf */
697 if ( (int_round & 0x7f800000) == 0x7f800000 ) {
698 do_round = 0;
699 }
700
701 /* perform round nearest tie away from zero */
702 if ( do_round != 0 ) {
703 int_round = int_round + 0x00008000;
704 }
705
706 /* create the bf16 value by shifting out the lower 16bits */
707 int_round = int_round >> 16;
708
709 out[i] = (libxsmm_bfloat16)int_round;
710 }
711 }
712
713
libxsmm_rne_convert_fp32_bf16(const float * in,libxsmm_bfloat16 * out,unsigned int len)714 LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
715 unsigned int i = 0;
716
717 /* truncate buffer to bf16 */
718 for ( i = 0; i < len; ++i ) {
719 unsigned int int_round = 0;
720 unsigned int do_round = 1;
721
722 int_round = *((unsigned int*)&(in[i]));
723
724 /* we don't round NaN and inf */
725 if ( (int_round & 0x7f800000) == 0x7f800000 ) {
726 do_round = 0;
727 }
728
729 /* perform round nearest tie even */
730 if ( do_round != 0 ) {
731 unsigned int fixup = (int_round >> 16) & 1;
732 int_round = int_round + 0x00007fff + fixup;
733 }
734
735 /* create the bf16 value by shifting out the lower 16bits */
736 int_round = int_round >> 16;
737
738 out[i] = (unsigned short)int_round;
739 }
740 }
741
742
libxsmm_convert_bf16_f32(const libxsmm_bfloat16 * in,float * out,unsigned int length)743 LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length) {
744 unsigned int i = 0;
745
746 /* up-convert is super simple */
747 for ( i = 0; i < length; ++i ) {
748 libxsmm_bfloat16_hp t;
749
750 t.i[1] = in[i];
751 t.i[0] = 0;
752 out[i] = t.f;
753 }
754 }
755
756