1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Alexander Heinecke, Hans Pabst (Intel Corp.) 10 ******************************************************************************/ 11 #ifndef LIBXSMM_DNN_H 12 #define LIBXSMM_DNN_H 13 14 #include "libxsmm_typedefs.h" 15 16 typedef unsigned int libxsmm_dnn_err_t; 17 18 /** Define error and warning codes */ 19 #define LIBXSMM_DNN_SUCCESS 0 20 21 #define LIBXSMM_DNN_WARN_FALLBACK 90000 22 #define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_N_BLOCKING 90001 23 #define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_C_BLOCKING 90002 24 #define LIBXSMM_DNN_WARN_RNN_SUBOPTIMAL_K_BLOCKING 90003 25 #define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING 90004 26 #define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING 90005 27 #define LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING 90006 28 29 #define LIBXSMM_DNN_ERR_GENERAL 100000 30 #define LIBXSMM_DNN_ERR_CREATE_HANDLE 100001 31 #define LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE 100002 32 #define LIBXSMM_DNN_ERR_INVALID_BLOCKING 100003 33 #define LIBXSMM_DNN_ERR_INVALID_HANDLE 100004 34 #define LIBXSMM_DNN_ERR_DATA_NOT_BOUND 100005 35 #define LIBXSMM_DNN_ERR_CREATE_TENSOR 100006 36 #define LIBXSMM_DNN_ERR_INVALID_TENSOR 100007 37 #define LIBXSMM_DNN_ERR_MISMATCH_TENSOR 100008 38 #define LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR 100009 39 #define LIBXSMM_DNN_ERR_INVALID_KIND 100010 40 #define LIBXSMM_DNN_ERR_INVALID_FORMAT_NCHW 100011 41 #define LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT 100012 42 #define LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT 100013 43 #define LIBXSMM_DNN_ERR_INVALID_FORMAT_CONVOLVE 100014 44 #define LIBXSMM_DNN_ERR_INVALID_FORMAT_KCRS 100015 45 #define LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL 100016 46 #define LIBXSMM_DNN_ERR_CREATE_LAYOUT 100017 47 #define LIBXSMM_DNN_ERR_INVALID_LAYOUT 100018 48 #define LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH 100019 49 #define LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED 100020 50 #define LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE 100021 51 #define LIBXSMM_DNN_ERR_INVALID_ALGO 100022 52 #define LIBXSMM_DNN_ERR_INVALID_PADDING 100023 53 #define LIBXSMM_DNN_ERR_UNKNOWN_BIAS_TYPE 100024 54 #define LIBXSMM_DNN_ERR_MISMATCH_BIAS 100025 55 #define LIBXSMM_DNN_ERR_INVALID_HANDLE_BIAS 100026 56 #define LIBXSMM_DNN_ERR_TIME_STEPS_TOO_SMALL 100027 57 #define LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS 100028 58 #define LIBXSMM_DNN_ERR_NOT_IMPLEMENTED 100029 59 #define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_ORDER 100030 60 #define LIBXSMM_DNN_ERR_FUSEDBN_UNSUPPORTED_FUSION 100031 61 #define LIBXSMM_DNN_ERR_INVALID_FORMAT_FUSEDBN 100032 62 #define LIBXSMM_DNN_ERR_UNSUPPORTED_POOLING 100033 63 #define LIBXSMM_DNN_ERR_INVALID_FORMAT_FC 100034 64 #define LIBXSMM_DNN_ERR_INVALID_RNN_TYPE 100035 65 #define LIBXSMM_DNN_ERR_RNN_INVALID_SEQ_LEN 100036 66 #define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_ORDER 100037 67 #define LIBXSMM_DNN_ERR_FUSEDGN_UNSUPPORTED_FUSION 100038 68 #define LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION 100039 69 70 /** Kinds of supported compute flavor operations. */ 71 typedef enum libxsmm_dnn_compute_kind { 72 /** Forward path */ 73 LIBXSMM_DNN_COMPUTE_KIND_FWD, 74 /** Backward path */ 75 LIBXSMM_DNN_COMPUTE_KIND_BWD, 76 /** Updated weights. */ 77 LIBXSMM_DNN_COMPUTE_KIND_UPD, 78 /** Backward and weightupdate combined, useful for RNNs */ 79 LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 80 /** All routines, need for some init routines. */ 81 LIBXSMM_DNN_COMPUTE_KIND_ALL 82 } libxsmm_dnn_compute_kind; 83 84 /** these are some quantization definitions, not sure if we want to 85 move them into some main part of LIBXSMM */ 86 /* @TODO check position of these declarations and defines */ 87 typedef union LIBXSMM_RETARGETABLE libxsmm_intfloat { 88 unsigned int ui; 89 float f; 90 } libxsmm_intfloat; 91 92 /* F32 masking defines */ 93 #define LIBXSNN_DNN_MASK_SIGN_F32 0x80000000 94 #define LIBXSMM_DNN_MASK_EXP_F32 0x7f800000 95 #define LIBXSMM_DNN_MASK_MANT_F32 0x007fffff 96 #define LIBXSMM_DNN_MASK_ABS_F32 0x7fffffff 97 #define LIBXSMM_DNN_MASK_FULL_F32 0xffffffff 98 #define LIBXSMM_DNN_MANT_SZ_F32 23 99 #define LIBXSMM_DNN_SZ_F32 32 100 101 /* DFP16 masking defines */ 102 #define LIBXSMM_DNN_MANT_DFP16 15 103 #define LIXSMMM_DNN_RES_DFP16 libxsmm_sexp2_i8i(-(LIBXSMM_DNN_MANT_DFP16)) 104 105 /* Quantization Rounding Defines */ 106 #define LIBXSMM_DNN_QUANT_NO_ROUND 80000 107 #define LIBXSMM_DNN_QUANT_BIAS_ROUND 80001 108 #define LIBXSMM_DNN_QUANT_STOCH_ROUND 80002 109 #define LIBXSMM_DNN_QUANT_NEAREST_ROUND 80003 110 #define LIBXSMM_DNN_QUANT_FPHW_ROUND 80004 111 112 /** get string of error code */ 113 LIBXSMM_API const char* libxsmm_dnn_get_error(libxsmm_dnn_err_t code); 114 LIBXSMM_API size_t libxsmm_dnn_typesize(libxsmm_dnn_datatype datatype); 115 LIBXSMM_API size_t libxsmm_dnn_get_simd_width(libxsmm_dnn_datatype datatype); 116 117 /** some quantization helper functions, 118 @TODO need to be integrated better for all different ways of quantizations */ 119 LIBXSMM_API void libxsmm_dnn_quantize( float* in_buffer, short* out_buffer, int length, unsigned char add_shift, unsigned char* scf, int round_mode ); 120 LIBXSMM_API void libxsmm_dnn_quantize_act( float* in_buffer, short* out_buffer, unsigned int N, unsigned int C, unsigned int H, unsigned int W, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ); 121 LIBXSMM_API void libxsmm_dnn_quantize_fil( float* in_buffer, short* out_buffer, unsigned int K, unsigned int C, unsigned int R, unsigned int S, unsigned int cblk_f32, unsigned int cblk_i16, unsigned int kblk_f32, unsigned int kblk_i16, unsigned int lp_blk, unsigned char add_shift, unsigned char* scf, int round_mode ); 122 LIBXSMM_API void libxsmm_dnn_dequantize( short* in_buffer, float* out_buffer, int length, unsigned char scf ); 123 124 /** some BF16<->FP32 conversion functions 125 @TODO we need to find a final place for those */ 126 LIBXSMM_API void libxsmm_truncate_convert_f32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int length); 127 LIBXSMM_API void libxsmm_rnaz_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len); 128 LIBXSMM_API void libxsmm_rne_convert_fp32_bf16(const float* in, libxsmm_bfloat16* out, unsigned int len); 129 LIBXSMM_API void libxsmm_convert_bf16_f32(const libxsmm_bfloat16* in, float* out, unsigned int length); 130 131 #endif /*LIBXSMM_DNN_H*/ 132 133