1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst, Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm_dnn.h>
12 #include "libxsmm_main.h"
13 
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <math.h>
18 #if defined(_OPENMP)
19 # include <omp.h>
20 #endif
21 #if defined(LIBXSMM_OFFLOAD_TARGET)
22 # pragma offload_attribute(pop)
23 #endif
24 
25 
/* Initializes the DNN domain; currently a no-op placeholder kept for API symmetry. */
LIBXSMM_API_INTERN void libxsmm_dnn_init(int target_arch)
{
  LIBXSMM_UNUSED(target_arch); /* accepted but unused; silences unused-parameter warnings */
}
30 
31 
/* Finalizes the DNN domain; currently nothing to tear down. */
LIBXSMM_API_INTERN void libxsmm_dnn_finalize(void)
{
}
35 
36 
libxsmm_dnn_get_feature_map_blocks(int C,int K,int * C_block,int * K_block,int * fm_lp_block,libxsmm_dnn_datatype datatype_in,libxsmm_dnn_datatype datatype_out)37 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_get_feature_map_blocks( int C, int K, int* C_block, int* K_block, int* fm_lp_block, libxsmm_dnn_datatype datatype_in, libxsmm_dnn_datatype datatype_out ) {
38   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
39   int ifmblock = 0;
40   int ofmblock = 0;
41   int lp_block = 0;
42   int tmp_max_c_block = 32;
43   int tmp_max_k_block = 32;
44   int tmp_block = 0;
45 
46   /* init libxsmm */
47   LIBXSMM_INIT
48 
49   /* C */
50   if (libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE) {
51     tmp_max_c_block = 64;
52   }
53   if ( C < tmp_max_c_block ) {
54     ifmblock = C;
55   } else {
56     for ( tmp_block = 1; tmp_block <= tmp_max_c_block; tmp_block *= 2 ) {
57       if ( C % tmp_block == 0 ) ifmblock = tmp_block;
58     }
59   }
60 
61   /* K */
62   if (libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE) {
63     tmp_max_k_block = 64;
64   }
65   if ( K < tmp_max_k_block ) {
66     ofmblock = K;
67   } else {
68     for ( tmp_block = 1; tmp_block <= tmp_max_k_block; tmp_block *= 2 ) {
69       if ( K % tmp_block == 0 ) ofmblock = tmp_block;
70     }
71   }
72 
73   /* when do we need VNNI format? */
74   if ( (datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
75     lp_block = 1;
76   } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
77     lp_block = 2;
78   } else if ( (datatype_in == LIBXSMM_DNN_DATATYPE_I16) && ((datatype_out == LIBXSMM_DNN_DATATYPE_I32) || (datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
79     lp_block = 2;
80   } else if (datatype_in == LIBXSMM_DNN_DATATYPE_I8) {
81     lp_block = 4;
82   } else {
83     status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
84     return status;
85   }
86 
87   *C_block = ifmblock;
88   *K_block = ofmblock;
89   *fm_lp_block = lp_block;
90 
91   return status;
92 }
93 
94 
/* Translates a LIBXSMM DNN status code (success, warning, or error) into a
 * human-readable message; unknown codes yield a generic message. */
LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code)
{
  switch (code) {
    case LIBXSMM_DNN_SUCCESS:
      return "LIBXSMM DNN Success!";
    case LIBXSMM_DNN_WARN_FALLBACK:
      return "LIBXSMM DNN Warning: Falling back to naive code as target is currently not supported by LIBXSMM!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING:
      return "LIBXSMM DNN Warning: RNN cell suboptimal minibatch blocking!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING:
      return "LIBXSMM DNN Warning: RNN cell suboptimal input feature blocking!";
    case LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING:
      return "LIBXSMM DNN Warning: RNN cell suboptimal output feature blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING:
      return "LIBXSMM DNN Warning: FC layer suboptimal minibatch blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING:
      return "LIBXSMM DNN Warning: FC layer suboptimal input feature blocking!";
    case LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING:
      return "LIBXSMM DNN Warning: FC layer suboptimal output feature blocking!";
    case LIBXSMM_DNN_ERR_GENERAL:
      return "LIBXSMM DNN Error: General error occurred!";
    case LIBXSMM_DNN_ERR_CREATE_HANDLE:
      return "LIBXSMM DNN Error: Handle creation failed!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE:
      return "LIBXSMM DNN Error: Requested datatype is not available!";
    case LIBXSMM_DNN_ERR_INVALID_BLOCKING:
      return "LIBXSMM DNN Error: Requested Input/Output buffer size cannot be blocked!";
    case LIBXSMM_DNN_ERR_INVALID_HANDLE:
      return "LIBXSMM DNN Error: An invalid handle was provided!";
    case LIBXSMM_DNN_ERR_DATA_NOT_BOUND:
      return "LIBXSMM DNN Error: Not all required sources and destinations have been bound to convolution!";
    case LIBXSMM_DNN_ERR_CREATE_TENSOR:
      return "LIBXSMM DNN Error: Tensor creation failed!";
    case LIBXSMM_DNN_ERR_INVALID_TENSOR:
      return "LIBXSMM DNN Error: Invalid tensor was specified!";
    case LIBXSMM_DNN_ERR_MISMATCH_TENSOR:
      return "LIBXSMM DNN Error: Tensor doesn't match handle it should be bind to!";
    case LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR:
      return "LIBXSMM DNN Error: Invalid handle or tensor!";
    case LIBXSMM_DNN_ERR_INVALID_KIND:
      return "LIBXSMM DNN Error: Invalid convolution kind!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW:
      return "LIBXSMM DNN Error: NCHW format is currently not natively supported by LIBXSMM!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT:
      return "LIBXSMM DNN Error: Unsupported destination format when copying data!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT:
      return "LIBXSMM DNN Error: Unsupported source format when copying data!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE:
      return "LIBXSMM DNN Error: Unsupported format when requesting a convolution!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS:
      return "LIBXSMM DNN Error: KCRS format is currently not natively supported by LIBXSMM!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL:
      return "LIBXSMM DNN Error: Invalid format was specified!";
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT:
      return "LIBXSMM DNN Error: Layout creation failed!";
    case LIBXSMM_DNN_ERR_INVALID_LAYOUT:
      return "LIBXSMM DNN Error: Invalid layout was specified!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH:
      return "LIBXSMM DNN Error: Unsupported architecture!";
    case LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED:
      return "LIBXSMM DNN Error: scratch binding failed as scratch was not allocated!";
    case LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE:
      return "LIBXSMM DNN Error: an unknown tensor type was provided!";
    case LIBXSMM_DNN_ERR_INVALID_ALGO:
      return "LIBXSMM DNN Error: Invalid algorithm was specified!";
    case LIBXSMM_DNN_ERR_INVALID_PADDING:
      return "LIBXSMM DNN Error: Invalid padding was specified!";
    case LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL:
      return "LIBXSMM DNN Error: time steps should be >= 2 for RNN/LSTM!";
    case LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS:
      return "LIBXSMM DNN Error: failed to create internal layout arrays!";
    case LIBXSMM_DNN_ERR_NOT_IMPLEMENTED:
      return "LIBXSMM DNN Error: the requested functionality is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER:
      return "LIBXSMM DNN Error: the requested order of fusion in batch norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in batch norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN:
      return "LIBXSMM DNN Error: Unsupported format when requesting a fused batch norm!";
    case LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING:
      return "LIBXSMM DNN Error: Unsupported pooling operations was requested!";
    case LIBXSMM_DNN_ERR_INVALID_FORMAT_FC:
      return "LIBXSMM DNN Error: Unsupported format when requesting a fullyconnected layer!";
    case LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN:
      return "LIBXSMM DNN Error: max sequence length is shorter than sequence length we attempt to set!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER:
      return "LIBXSMM DNN Error: the requested order of fusion in group norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in group norm is right now not implemented!";
    case LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION:
      return "LIBXSMM DNN Error: the requested fusion in fullyconnected is right now not implemented!";
    default:
      return "LIBXSMM DNN Error: Unknown error or warning occurred!";
  }
}
190 
191 
libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype)192 LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype)
193 {
194   switch (datatype) {
195     case LIBXSMM_DNN_DATATYPE_F32: return 4;
196     case LIBXSMM_DNN_DATATYPE_I32: return 4;
197     case LIBXSMM_DNN_DATATYPE_BF16:return 2;
198     case LIBXSMM_DNN_DATATYPE_I16: return 2;
199     case LIBXSMM_DNN_DATATYPE_I8:  return 1;
200     /* no error expected as enumeration really arrives at an enum; compiler-checked */
201     default: return 1;
202   }
203 }
204 
205 
libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype)206 LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype)
207 {
208   size_t l_cl_width_bytes;
209 
210   /* init libxsmm */
211   LIBXSMM_INIT
212 
213   if ( libxsmm_target_archid == LIBXSMM_X86_GENERIC ) {
214     l_cl_width_bytes = libxsmm_dnn_typesize(datatype);
215   } else if ( libxsmm_target_archid == LIBXSMM_X86_SSE3 ||
216       libxsmm_target_archid == LIBXSMM_X86_SSE4 ) {
217     l_cl_width_bytes = 16;
218   } else if ( libxsmm_target_archid == LIBXSMM_X86_AVX2 ||
219       libxsmm_target_archid == LIBXSMM_X86_AVX ) {
220     l_cl_width_bytes = 32;
221   } else {
222     l_cl_width_bytes = 64;
223   }
224 
225   return l_cl_width_bytes/libxsmm_dnn_typesize(datatype);
226 }
227 
228 
libxsmm_internal_get_max(float * in_buffer,int length)229 LIBXSMM_API_INLINE float libxsmm_internal_get_max( float* in_buffer, int length ) {
230   float absmax_value = LIBXSMM_ABS(in_buffer[0]);
231   int i = 0;
232 #ifdef _OPENMP
233   LIBXSMM_OMP_VAR(i);
234 # pragma omp parallel private(i)
235   {
236     float my_absmax_value = absmax_value;
237 #   pragma omp for
238     for (i = 0; i < length; ++i ) {
239       if (LIBXSMM_ABS(in_buffer[i]) > my_absmax_value) {
240         my_absmax_value = LIBXSMM_ABS(in_buffer[i]);
241       }
242     }
243 #   pragma omp critical
244     {
245       if (my_absmax_value > absmax_value) {
246         absmax_value = my_absmax_value;
247       }
248     }
249   }
250 #else
251   for (i = 1; i < length; ++i ) {
252     if (LIBXSMM_ABS(in_buffer[i]) > absmax_value) {
253       absmax_value = LIBXSMM_ABS(in_buffer[i]);
254     }
255   }
256 #endif
257 
258   return absmax_value;
259 }
260 
261 
libxsmm_internal_get_max_exp(float * in_buffer,int length)262 LIBXSMM_API_INLINE unsigned char libxsmm_internal_get_max_exp( float* in_buffer, int length ) {
263   libxsmm_intfloat val_exp;
264   unsigned char max_exp = 0;
265 
266   /* bit-wise conversion to int */
267   val_exp.f = libxsmm_internal_get_max( in_buffer, length );
268   /* shift by mantissa to the right and convert to char */
269   max_exp = (unsigned char)((val_exp.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32);
270 
271   return max_exp;
272 }
273 
274 
libxsmm_internal_quantize_scalar_no_scf(float input,unsigned char max_exp,unsigned char add_shift,int round_mode)275 LIBXSMM_API_INLINE short libxsmm_internal_quantize_scalar_no_scf( float input, unsigned char max_exp, unsigned char add_shift, int round_mode ) {
276   libxsmm_intfloat value;
277   unsigned int qvalue = 0;
278   unsigned int mant = 0;
279   unsigned int sign = 0;
280   unsigned char rhs = 0;
281   unsigned char exp_off = 0;
282 
283   /* init libxsmm */
284   LIBXSMM_INIT
285 
286   /* in case of zero we don't need to do anything */
287   if (LIBXSMM_FEQ(input, 0)) {
288     qvalue = 0;
289   } else {
290     /* let's get a float copy to work on */
291     /* vinp = LIBXSMM_INTRINSICS_MM512_LOAD_PS( in_buffer[i] ); */
292     value.f = input;
293     /* let's compute the offset of the current exp at pos i from max offset, we need to mask the sign bit though */
294     /*__m512i vexp     = _mm512_cvtps_epi32(_mm512_getexp_ps (vinp));
295       __m512i vexp_off = _mm512_sub_epi32(maxexpf, vexp);*/
296     exp_off = (unsigned char)(max_exp - ((value.ui & LIBXSMM_DNN_MASK_ABS_F32) >> LIBXSMM_DNN_MANT_SZ_F32));
297     /* cut out mantissa and set leading bit */
298     /*__m512i mmask = _mm512_set1_epi32(LIBXSMM_DNN_MASK_MANT_F32);
299       __m512i vmant = _mm512_or_epi32(_mm512_set1_epi32(0x1 << LIBXSMM_DNN_MANT_SZ_F32), _mm512_and_epi32( _mm512_castps_si512( vinp ), mmask));*/
300     mant = ((0x1 << LIBXSMM_DNN_MANT_SZ_F32) | (value.ui & LIBXSMM_DNN_MASK_MANT_F32));
301     /* extract sign */
302     /* __mmask16 smask =  _mm512_cmplt_ps_mask (inp, _mm512_set1_ps(0)); */
303     sign = ((value.ui & LIBXSNN_DNN_MASK_SIGN_F32) >> (LIBXSMM_DNN_SZ_F32-1));
304     /* calculate rhs, be aware of the now explicit leading bit, @TODO add DFP8/4 */
305     rhs = (unsigned char)((LIBXSMM_DNN_MANT_SZ_F32+1) - LIBXSMM_DNN_MANT_DFP16 + exp_off + add_shift);
306     /* some safety, to generate 0 when we fall off quant region, @TODO issue a LIBXSMM Warning that we shifted out the entire mantissa */
307     if (rhs > (LIBXSMM_DNN_MANT_SZ_F32+1)) {
308       rhs = (LIBXSMM_DNN_MANT_SZ_F32+1);
309     }
310     /* finally shift the value into the region we need, this is now a 15-add_rhs bit number for the max value in in_buffer */
311     qvalue = (mant >> rhs);
312     /* handle sign, 2 complement */
313     if ( (sign > 0) && (qvalue > 0) ) {
314       qvalue = (~qvalue + 1);
315     }
316 
317     if (round_mode == LIBXSMM_DNN_QUANT_BIAS_ROUND) {
318       /* biased rounding towards next bigger number */
319       /* first let's determine in the original number if we need a bias rounding, @TODO need fix for F64 */
320       int bias_needed = (mant & (0x3 << (rhs-2)));
321       /* apply bias */
322       if (bias_needed > 0) {
323         qvalue++;
324       }
325     } else if (round_mode == LIBXSMM_DNN_QUANT_NEAREST_ROUND) {
326       int nearest_needed = (mant & (0x1 << (rhs-1)));
327       /* apply rounding */
328       if ((nearest_needed > 0) && (rhs > 1)) {
329         qvalue++;
330       }
331     } else if (round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND) {
332       /* stochastic rounding, as implemented in the IBM paper from 2015, @TODO, fix F64 and DFP8 */
333       const float eps = LIXSMMM_DNN_RES_DFP16;
334       /* coverity[dont_call] */
335       const float r = (float)rand();
336       libxsmm_intfloat fvalue;
337       float p, q;
338       /* masking all bits which will be shifted out */
339       fvalue.ui = value.ui & ((LIBXSMM_DNN_MASK_FULL_F32) << rhs);
340       /* drawing a random number */
341       p = r/((float)RAND_MAX);
342       q = (input - fvalue.f)/eps;
343       /* apply rounding if needed */
344       if ((p + q) > 0.5f) {
345         ++qvalue;
346       }
347     } else {
348       /* do nothing about rounding, just chop */
349     }
350   }
351 
352   return (short)qvalue;
353 }
354 
355 
/* @TODO make this routine aware of any int type */
/* Quantizes a dense FP32 buffer into 16-bit fixed point (DFP16); the common
 * scale factor is returned via "scf". LIBXSMM_DNN_QUANT_FPHW_ROUND uses an
 * FP-multiply based path (vectorized when length is a multiple of 16 on
 * AVX-512 builds); all other round modes use the shift-based scalar kernel. */
LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  int i = 0;

  /* init libxsmm */
  LIBXSMM_INIT

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    const float max_value = libxsmm_internal_get_max( in_buffer, length );
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( length % 16 == 0 ) {
      /* fast path: quantize 16 floats per iteration with streaming stores */
      __m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
#     pragma omp parallel for private(i)
#endif
      for (i = 0; i < length; i+=16 ) {
        _mm256_stream_si256( (__m256i *)&(out_buffer[i]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i]), vscfq ) );
      }
    } else {
#endif
#ifdef _OPENMP
#     pragma omp parallel for private(i)
#endif
      for (i = 0; i < length; ++i ) {
        out_buffer[i] = (short)LIBXSMM_ROUNDF(in_buffer[i] * scfq);
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, length );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i)
#endif
    for (i = 0; i < length; ++i ) {
      out_buffer[i] = libxsmm_internal_quantize_scalar_no_scf( in_buffer[i], max_exp, add_shift, round_mode );
    }

    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
419 
420 
/* Quantizes an activation tensor (N x C x H x W, channel-blocked layout) from
 * FP32 into DFP16 (short) while repacking from cblk_f32-wide channel blocks
 * into cblk_i16 x lp_blk low-precision blocks; the common scale factor is
 * returned via "scf". */
LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  LIBXSMM_VLA_DECL(5, const float, in,  in_buffer,  C/cblk_f32, H, W, cblk_f32);
  LIBXSMM_VLA_DECL(6, short, out, out_buffer, C/(cblk_i16*lp_blk), H, W, cblk_i16, lp_blk);
  const unsigned int cblk = C/(cblk_i16*lp_blk);
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6;

  /* init libxsmm */
  LIBXSMM_INIT

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  assert((C % cblk_i16) == 0);

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    const float max_value = libxsmm_internal_get_max( in_buffer, N*C*H*W );
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( (cblk_f32 == 16) && (cblk_i16*lp_blk == 16) ) {
      /* fast path: source and destination channel blockings coincide, so the
       * tensor can be quantized as one linear stream, 16 floats at a time */
      __m512 vscfq = _mm512_set1_ps(scfq);
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1);
#     pragma omp parallel for private(i1)
#endif
      for (i1 = 0; i1 < (int)(N*C*H*W); i1 += 16 ) {
        _mm256_stream_si256( (__m256i *)&(out_buffer[i1]), LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( &(in_buffer[i1]), vscfq ) );
      }
    } else {
#endif
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)N; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)H; ++i3 ) {
            for (i4 = 0; i4 < (int)W; ++i4 ) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
                for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
                  /* map the low-precision channel position (i2,i5,i6) back to the f32 blocking */
                  const int fi1 = i1;
                  const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
                  LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                  LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32) * scfq);
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant act\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, N*C*H*W );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)N; ++i1 ) {
      for (i2 = 0; i2 < (int)cblk; ++i2 ) {
        for (i3 = 0; i3 < (int)H; ++i3 ) {
          for (i4 = 0; i4 < (int)W; ++i4 ) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
              for (i6 = 0; i6 < (int)lp_blk; ++i6 ) {
                /* same index remapping as the FPHW path above */
                const int fi1 = i1;
                const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)/cblk_f32;
                const int fi3 = i3;
                const int fi4 = i4;
                const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i6)%cblk_f32;
                LIBXSMM_VLA_ACCESS(6, out, i1, i2, i3, i4, i5, i6, cblk, H, W, cblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
                LIBXSMM_VLA_ACCESS(5, in, fi1, fi2, fi3, fi4, fi5, C / cblk_f32, H, W, cblk_f32), max_exp, add_shift, round_mode);
              }
            }
          }
        }
      }
    }

    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
524 
525 
/* Quantizes a filter tensor (K x C x R x S, doubly channel-blocked layout) from
 * FP32 into DFP16 (short) while repacking from (cblk_f32, kblk_f32) blocks into
 * (cblk_i16, kblk_i16, lp_blk) low-precision blocks; the common scale factor is
 * returned via "scf". */
LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ) {
  LIBXSMM_VLA_DECL(6, const float, in,  in_buffer,  C/cblk_f32, R, S, cblk_f32, kblk_f32);
  LIBXSMM_VLA_DECL(7, short, out, out_buffer, C/(cblk_i16*lp_blk), R, S, cblk_i16, kblk_i16, lp_blk);
  unsigned int cblk = C/(cblk_i16*lp_blk);
  unsigned int kblk = K/kblk_i16;
  int i1 = 0, i2 = 0, i3 = 0, i4 = 0, i5, i6, i7;

  /* some quick and dirty checks */
  assert((C % cblk_f32) == 0);
  assert((C % (cblk_i16*lp_blk)) == 0);
  assert((K % kblk_f32) == 0);
  assert((K % kblk_i16) == 0);
  assert((lp_blk % 2) == 0);

  /* init libxsmm */
  LIBXSMM_INIT

  /* in case we are using FP-Mul based quantization we use a different path for now
     @TODO let's unify the paths by using the similar vectorization for both */
  if ( round_mode == LIBXSMM_DNN_QUANT_FPHW_ROUND ) {
    const float max_value = libxsmm_internal_get_max( in_buffer, K*C*R*S );
    int maxexp = 0;
    /* take return value of LIBXSMM_FREXPF to mute static analysis issue */
    float scfq = LIBXSMM_FREXPF(max_value, &maxexp);
    maxexp -= (15/*LIBXSMM_DNN_MANT_DFP16?*/ - add_shift);
    scfq = libxsmm_sexp2_i8i(-maxexp);

#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    if ( (kblk_f32 == 16) && (cblk_f32 == 16) && (kblk_i16 == 16) && (cblk_i16*lp_blk == 16) ) {
      /* fast path: all blockings are 16-wide; quantize two adjacent input
       * channels at once and interleave them into the VNNI pair layout */
      const __m512 vscfq = _mm512_set1_ps(scfq);
      const __m512i permute_compact_idx = _mm512_set_epi32(15,14,13,12,7,6,5,4,11,10,9,8,3,2,1,0);
#ifdef _OPENMP
#     pragma omp parallel for private(i1, i2, i3, i4, i5) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)R; ++i3 ) {
            for (i4 = 0; i4 < (int)S; ++i4 ) {
              for (i5 = 0; i5 < 16; i5+=2 ) {
                __m256i even_ch = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 0, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
                __m256i odd_ch  = LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16(
                  &LIBXSMM_VLA_ACCESS(6, in, i1, i2, i3, i4, i5 + 1, 0, C / cblk_f32, R, S, cblk_f32, kblk_f32), vscfq);
                __m256i compressed_lo = _mm256_unpacklo_epi16(even_ch, odd_ch);
                __m256i compressed_hi = _mm256_unpackhi_epi16(even_ch, odd_ch);
                __m512i compact =  _mm512_inserti64x4( _mm512_setzero_si512(), compressed_lo, 0);
                compact = _mm512_inserti64x4(compact, compressed_hi, 1);
                compact = _mm512_permutexvar_epi32(permute_compact_idx, compact);
                LIBXSMM_INTRINSICS_MM512_STREAM_SI512(
                  (void*)&LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5 / 2, 0, 0, cblk, R, S, cblk_i16, kblk_i16, lp_blk),
                  compact);
              }
            }
          }
        }
      }
    } else {
#endif
#ifdef _OPENMP
      LIBXSMM_OMP_VAR(i1); LIBXSMM_OMP_VAR(i2); LIBXSMM_OMP_VAR(i3); LIBXSMM_OMP_VAR(i4); LIBXSMM_OMP_VAR(i5); LIBXSMM_OMP_VAR(i6); LIBXSMM_OMP_VAR(i7);
#     pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
      for (i1 = 0; i1 < (int)kblk; ++i1 ) {
        for (i2 = 0; i2 < (int)cblk; ++i2 ) {
          for (i3 = 0; i3 < (int)R; ++i3 ) {
            for (i4 = 0; i4 < (int)S; ++i4 ) {
              for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
                for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
                  for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
                    /* map the low-precision (K,C) block position back to the f32 blocking */
                    const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32;
                    const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32;
                    const int fi3 = i3;
                    const int fi4 = i4;
                    const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32;
                    const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32;
                    LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = (short)LIBXSMM_ROUNDF(
                    LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32) * scfq);
                  }
                }
              }
            }
          }
        }
      }
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
    }
#endif
    /* @TODO, we need to potentially fix this unsigned char problem */
#if !defined(NDEBUG) /* library code is expected to be mute */
    if (maxexp > 0) {
      fprintf(stderr, "error quant fil\n");
    }
#endif
    *scf = (unsigned char)(-maxexp);
  } else {
    /* get max exponent */
    unsigned char max_exp = libxsmm_internal_get_max_exp( in_buffer, K*C*R*S );

    /* if we go for stochastic rounding, let's initialize random seed */
    if ( round_mode == LIBXSMM_DNN_QUANT_STOCH_ROUND ) {
      srand(libxsmm_timer_tick() % ((unsigned int)-1));
    }

#ifdef _OPENMP
#   pragma omp parallel for private(i1, i2, i3, i4, i5, i6, i7) LIBXSMM_OPENMP_COLLAPSE(4)
#endif
    for (i1 = 0; i1 < (int)kblk; ++i1 ) {
      for (i2 = 0; i2 < (int)cblk; ++i2 ) {
        for (i3 = 0; i3 < (int)R; ++i3 ) {
          for (i4 = 0; i4 < (int)S; ++i4 ) {
            for (i5 = 0; i5 < (int)cblk_i16; ++i5 ) {
              for (i6 = 0; i6 < (int)kblk_i16; ++i6 ) {
                for (i7 = 0; i7 < (int)lp_blk; ++i7 ) {
                  /* same index remapping as the FPHW path above */
                  const int fi1 = ((i1*kblk_i16)+i6)/kblk_f32;
                  const int fi2 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)/cblk_f32;
                  const int fi3 = i3;
                  const int fi4 = i4;
                  const int fi5 = ((i2*cblk_i16*lp_blk)+(i5*lp_blk)+i7)%cblk_f32;
                  const int fi6 = ((i1*kblk_i16)+i6)%kblk_f32;
                  LIBXSMM_VLA_ACCESS(7, out, i1, i2, i3, i4, i5, i6, i7, cblk, R, S, cblk_i16, kblk_i16, lp_blk) = libxsmm_internal_quantize_scalar_no_scf(
                  LIBXSMM_VLA_ACCESS(6, in, fi1, fi2, fi3, fi4, fi5, fi6, C / cblk_f32, R, S, cblk_f32, kblk_f32), max_exp, add_shift, round_mode);
                }
              }
            }
          }
        }
      }
    }

    *scf = (unsigned char)(14 - add_shift - (max_exp - 127));
  }
}
658 
659 
libxsmm_dnn_dequantize(short * in_buffer,float * out_buffer,int length,unsigned char scf)660 LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf ) {
661   const float val_exp = libxsmm_sexp2_i8i(-scf);
662   int i = 0;
663 
664 #ifdef _OPENMP
665 # pragma omp parallel for private(i)
666 #endif
667   for ( i = 0; i < length; ++i ) {
668     out_buffer[i] = ((float)in_buffer[i])*val_exp;
669   }
670 }
671 
672 
libxsmm_truncate_convert_f32_bf16(const float * in,libxsmm_bfloat16 * out,unsigned int length)673 LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length) {
674   unsigned int i = 0;
675 
676   /* truncate buffer to bf16 */
677   for ( i = 0; i < length; ++i ) {
678     libxsmm_bfloat16_hp t;
679 
680     t.f = in[i];
681     out[i] = t.i[1];
682   }
683 }
684 
685 
libxsmm_rnaz_convert_fp32_bf16(const float * in,libxsmm_bfloat16 * out,unsigned int len)686 LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
687   unsigned int i = 0;
688 
689   /* truncate buffer to bf16 */
690   for ( i = 0; i < len; ++i ) {
691     unsigned int int_round = 0;
692     unsigned int do_round = 1;
693 
694     int_round = *((unsigned int*)&(in[i]));
695 
696     /* we don't round NaN and inf */
697     if ( (int_round & 0x7f800000) == 0x7f800000 ) {
698       do_round = 0;
699     }
700 
701     /* perform round nearest tie away from zero */
702     if ( do_round != 0 ) {
703       int_round = int_round + 0x00008000;
704     }
705 
706     /* create the bf16 value by shifting out the lower 16bits */
707     int_round = int_round >> 16;
708 
709     out[i] = (libxsmm_bfloat16)int_round;
710   }
711 }
712 
713 
libxsmm_rne_convert_fp32_bf16(const float * in,libxsmm_bfloat16 * out,unsigned int len)714 LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len) {
715   unsigned int i = 0;
716 
717   /* truncate buffer to bf16 */
718   for ( i = 0; i < len; ++i ) {
719     unsigned int int_round = 0;
720     unsigned int do_round = 1;
721 
722     int_round = *((unsigned int*)&(in[i]));
723 
724     /* we don't round NaN and inf */
725     if ( (int_round & 0x7f800000) == 0x7f800000 ) {
726       do_round = 0;
727     }
728 
729     /* perform round nearest tie even */
730     if ( do_round != 0 ) {
731       unsigned int fixup = (int_round >> 16) & 1;
732       int_round = int_round + 0x00007fff + fixup;
733     }
734 
735     /* create the bf16 value by shifting out the lower 16bits */
736     int_round = int_round >> 16;
737 
738     out[i] = (unsigned short)int_round;
739   }
740 }
741 
742 
libxsmm_convert_bf16_f32(const libxsmm_bfloat16 * in,float * out,unsigned int length)743 LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length) {
744   unsigned int i = 0;
745 
746   /* up-convert is super simple */
747   for ( i = 0; i < length; ++i ) {
748     libxsmm_bfloat16_hp t;
749 
750     t.i[1] = in[i];
751     t.i[0] = 0;
752     out[i] = t.f;
753   }
754 }
755 
756