1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_DNN_H
12 #define LIBXSMM_DNN_H
13 
14 #include "libxsmm_typedefs.h"
15 
/**
 * Status code type of the DNN API. libxsmm_dnn_get_error(), declared in this
 * header, accepts a value of this type and returns a string for it.
 */
typedef unsigned int libxsmm_dnn_err_t;

/**
 * Define error and warning codes.
 * The values below fall into three disjoint numeric ranges:
 *   - 0                 success,
 *   - 90000 .. 90006    warnings (LIBXSMM_DNN_WARN_*),
 *   - 100000 .. 100039  errors (LIBXSMM_DNN_ERR_*).
 * NOTE(review): the exact condition under which each code is raised is not
 * visible in this header; consult the implementation before relying on it.
 */
#define LIBXSMM_DNN_SUCCESS                             0

/* Warning codes (90000 range). */
#define LIBXSMM_DNN_WARN_FALLBACK                   90000
#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING  90001
#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING  90002
#define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING  90003
#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING   90004
#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING   90005
#define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING   90006

/* Error codes (100000 range); values are consecutive without gaps. */
#define LIBXSMM_DNN_ERR_GENERAL                    100000
#define LIBXSMM_DNN_ERR_CREATE_HANDLE              100001
#define LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE       100002
#define LIBXSMM_DNN_ERR_INVALID_BLOCKING           100003
#define LIBXSMM_DNN_ERR_INVALID_HANDLE             100004
#define LIBXSMM_DNN_ERR_DATA_NOT_BOUND             100005
#define LIBXSMM_DNN_ERR_CREATE_TENSOR              100006
#define LIBXSMM_DNN_ERR_INVALID_TENSOR             100007
#define LIBXSMM_DNN_ERR_MISMATCH_TENSOR            100008
#define LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR      100009
#define LIBXSMM_DNN_ERR_INVALID_KIND               100010
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW        100011
#define LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT     100012
#define LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT     100013
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE    100014
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS        100015
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL     100016
#define LIBXSMM_DNN_ERR_CREATE_LAYOUT              100017
#define LIBXSMM_DNN_ERR_INVALID_LAYOUT             100018
#define LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH           100019
#define LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED        100020
#define LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE        100021
#define LIBXSMM_DNN_ERR_INVALID_ALGO               100022
#define LIBXSMM_DNN_ERR_INVALID_PADDING            100023
#define LIBXSMM_DNN_ERR_UNKNOWN_BIAS_TYPE          100024
#define LIBXSMM_DNN_ERR_MISMATCH_BIAS              100025
#define LIBXSMM_DNN_ERR_INVALID_HANDLE_BIAS        100026
#define LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL       100027
#define LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS       100028
#define LIBXSMM_DNN_ERR_NOT_IMPLEMENTED            100029
#define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER  100030
#define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION 100031
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN     100032
#define LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING        100033
#define LIBXSMM_DNN_ERR_INVALID_FORMAT_FC          100034
#define LIBXSMM_DNN_ERR_INVALID_RNN_TYPE           100035
#define LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN        100036
#define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER  100037
#define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION 100038
#define LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION      100039
69 
/**
 * Selects which compute flavor (pass) an operation refers to.
 * The enumerators are numbered consecutively starting at zero.
 */
typedef enum libxsmm_dnn_compute_kind {
  LIBXSMM_DNN_COMPUTE_KIND_FWD    = 0, /**< Forward path. */
  LIBXSMM_DNN_COMPUTE_KIND_BWD    = 1, /**< Backward path. */
  LIBXSMM_DNN_COMPUTE_KIND_UPD    = 2, /**< Updated weights. */
  LIBXSMM_DNN_COMPUTE_KIND_BWDUPD = 3, /**< Backward and weight update combined, useful for RNNs. */
  LIBXSMM_DNN_COMPUTE_KIND_ALL    = 4  /**< All routines, needed by some init routines. */
} libxsmm_dnn_compute_kind;
83 
/** These are some quantization definitions; it is not yet decided whether
    they should move into some main part of LIBXSMM. */
/* @TODO check position of these declarations and defines */
/**
 * Union that overlays an unsigned int and a float in the same storage.
 * NOTE(review): presumably used to access the raw bit pattern of a float
 * (e.g., together with the F32 masks in this header); this assumes that
 * unsigned int and float have the same size - TODO confirm.
 */
typedef union LIBXSMM_RETARGETABLE libxsmm_intfloat {
  /* unsigned integer view of the storage */
  unsigned int ui;
  /* floating-point view of the storage */
  float f;
} libxsmm_intfloat;
91 
/*
 * F32 masking defines.
 * The masks partition a 32-bit word into a sign bit (bit 31), an 8-bit field
 * (bits 30..23) and a 23-bit field (bits 22..0):
 *   SIGN | EXP | MANT == FULL, and EXP | MANT == ABS.
 */
#define LIBXSMM_DNN_MASK_SIGN_F32      0x80000000
/* Historic misspelling ("LIBXSNN") kept as an alias so that existing users keep compiling. */
#define LIBXSNN_DNN_MASK_SIGN_F32      LIBXSMM_DNN_MASK_SIGN_F32
#define LIBXSMM_DNN_MASK_EXP_F32       0x7f800000
#define LIBXSMM_DNN_MASK_MANT_F32      0x007fffff
#define LIBXSMM_DNN_MASK_ABS_F32       0x7fffffff
#define LIBXSMM_DNN_MASK_FULL_F32      0xffffffff
/* Number of bits covered by LIBXSMM_DNN_MASK_MANT_F32. */
#define LIBXSMM_DNN_MANT_SZ_F32        23
/* Number of bits covered by LIBXSMM_DNN_MASK_FULL_F32. */
#define LIBXSMM_DNN_SZ_F32             32
100 
/*
 * DFP16 masking defines.
 * LIBXSMM_DNN_RES_DFP16 expands to a call of libxsmm_sexp2_i8i (declared
 * elsewhere in LIBXSMM) with the negated mantissa width as argument.
 * NOTE(review): presumably this evaluates to 2^-15, i.e., the DFP16
 * resolution - confirm against libxsmm_sexp2_i8i.
 */
#define LIBXSMM_DNN_MANT_DFP16         15
#define LIBXSMM_DNN_RES_DFP16          libxsmm_sexp2_i8i(-(LIBXSMM_DNN_MANT_DFP16))
/* Historic misspelling ("LIXSMMM") kept as an alias so that existing users keep compiling. */
#define LIXSMMM_DNN_RES_DFP16          LIBXSMM_DNN_RES_DFP16
104 
/*
 * Quantization Rounding Defines.
 * Distinct selector values in the 80000 range (80000 .. 80004).
 * NOTE(review): presumably these are the values accepted by the `round_mode`
 * argument of the libxsmm_dnn_quantize* functions declared in this header -
 * confirm against the implementation.
 */
#define LIBXSMM_DNN_QUANT_NO_ROUND       80000
#define LIBXSMM_DNN_QUANT_BIAS_ROUND     80001
#define LIBXSMM_DNN_QUANT_STOCH_ROUND    80002
#define LIBXSMM_DNN_QUANT_NEAREST_ROUND  80003
#define LIBXSMM_DNN_QUANT_FPHW_ROUND     80004
111 
/**
 * Get the string of an error code.
 * @param code status code of type libxsmm_dnn_err_t.
 * @return pointer to a constant character string for the given code.
 * NOTE(review): the string returned for an unknown code is not visible from
 * this header - check the implementation.
 */
LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code);
/**
 * Query a size for the given datatype.
 * NOTE(review): presumably the element size in bytes - TODO confirm.
 */
LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype);
/**
 * Query the SIMD width associated with the given datatype.
 * NOTE(review): unit (elements vs. bytes) and the dependence on the target
 * architecture are not visible from this header - TODO confirm.
 */
LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype);
116 
/** Some quantization helper functions.
    @TODO need to be integrated better for all the different ways of quantization.
    All quantize variants read float data from in_buffer and write short
    (16-bit on common ABIs) data to out_buffer.
    NOTE(review): `scf` is passed by pointer to the quantize functions and by
    value to libxsmm_dnn_dequantize; presumably quantization stores the chosen
    scaling factor through the pointer - TODO confirm.
    NOTE(review): `round_mode` presumably takes one of the
    LIBXSMM_DNN_QUANT_*_ROUND values - TODO confirm. */
/* Quantize a flat buffer of `length` values. */
LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode );
/* Quantize an activation tensor described by N, C, H, W and the given blocking
   factors (NOTE(review): exact memory layouts of input/output are defined by
   the implementation - TODO confirm). */
LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode );
/* Quantize a filter tensor described by K, C, R, S and the given blocking
   factors (NOTE(review): exact memory layouts of input/output are defined by
   the implementation - TODO confirm). */
LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode );
/* Convert `length` short values from in_buffer to float values in out_buffer
   using the scaling factor `scf`. */
LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf );
123 
/** Some BF16<->FP32 conversion functions.
    @TODO we need to find a final place for those.
    Each function takes a read-only input array, an output array, and an
    element count.
    NOTE(review): the names mix "f32"/"fp32" and the count parameter is called
    `length` or `len`; renaming the functions would break the API, hence only
    flagged here. */
/* FP32 -> BF16; judging by the name, the value is truncated (no rounding) - TODO confirm. */
LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length);
/* FP32 -> BF16; "rnaz" presumably means round-to-nearest, ties away from zero - TODO confirm. */
LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len);
/* FP32 -> BF16; "rne" presumably means round-to-nearest, ties to even - TODO confirm. */
LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len);
/* BF16 -> FP32. */
LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length);
130 
131 #endif /*LIBXSMM_DNN_H*/
132 
133