1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Hans Pabst (Intel Corp.) 10 ******************************************************************************/ 11 #ifndef LIBXSMM_TYPEDEFS_H 12 #define LIBXSMM_TYPEDEFS_H 13 14 #include "libxsmm_macros.h" 15 16 /** Check ILP64 configuration for sanity. */ 17 #if !defined(LIBXSMM_ILP64) || (0 == LIBXSMM_ILP64 && defined(MKL_ILP64)) 18 # error "Inconsistent ILP64 configuration detected!" 19 #elif (0 != LIBXSMM_ILP64 && !defined(MKL_ILP64)) 20 # define MKL_ILP64 21 #endif 22 #if (0 != LIBXSMM_ILP64) 23 # define LIBXSMM_BLASINT_NBITS 64 24 # define LIBXSMM_BLASINT long long 25 #else /* LP64 */ 26 # define LIBXSMM_BLASINT_NBITS 32 27 # define LIBXSMM_BLASINT int 28 #endif 29 30 /** Generic prefetches; similar to LIBXSMM_PREFETCH_AUTO (libxsmm_frontend.h) */ 31 #define LIBXSMM_PREFETCH_SIGONLY 1 32 #define LIBXSMM_PREFETCH_NONE 0 33 34 /** Helper macro for type names. */ 35 #define LIBXSMM_TYPENAME(TYPE) LIBXSMM_STRINGIFY(LIBXSMM_CONCATENATE(LIBXSMM_TYPENAME_, TYPE)) 36 #define LIBXSMM_TYPENAME_double f64 37 #define LIBXSMM_TYPENAME_float f32 38 #define LIBXSMM_TYPENAME_libxsmm_bfloat16 bf16 39 #define LIBXSMM_TYPENAME_int i32 40 #define LIBXSMM_TYPENAME_short i16 41 #define LIBXSMM_TYPENAME_char i8 42 43 /** Helper macro for type information: INFO := { FP }. 
*/ 44 #define LIBXSMM_TYPEINFO(TYPE, INFO) LIBXSMM_CONCATENATE4(LIBXSMM_TYPEINFO_, INFO, _, TYPE) 45 #define LIBXSMM_TYPEINFO_FP_double 1 46 #define LIBXSMM_TYPEINFO_FP_float 1 47 #define LIBXSMM_TYPEINFO_FP_libxsmm_bfloat16 1 48 #define LIBXSMM_TYPEINFO_FP_int 0 49 #define LIBXSMM_TYPEINFO_FP_short 0 50 #define LIBXSMM_TYPEINFO_FP_char 0 51 52 /** Helper macro for type postfixes. */ 53 #define LIBXSMM_TYPESYMBOL(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_TYPESYMBOL_, TYPE) 54 #define LIBXSMM_TYPESYMBOL_double F64 55 #define LIBXSMM_TYPESYMBOL_float F32 56 #define LIBXSMM_TYPESYMBOL_libxsmm_bfloat16 BF16 57 #define LIBXSMM_TYPESYMBOL_int I32 58 #define LIBXSMM_TYPESYMBOL_short I16 59 #define LIBXSMM_TYPESYMBOL_char I8 60 61 #define LIBXSMM_TYPESIZE(ENUM) ( \ 62 ((int)(ENUM)) == LIBXSMM_DATATYPE_F64 ? 8 : ( \ 63 ((int)(ENUM)) == LIBXSMM_DATATYPE_F32 ? 4 : ( \ 64 ((int)(ENUM)) == LIBXSMM_DATATYPE_BF16 ? 2 : ( \ 65 ((int)(ENUM)) == LIBXSMM_DATATYPE_I32 ? 4 : ( \ 66 ((int)(ENUM)) == LIBXSMM_DATATYPE_I16 ? 2 : ( \ 67 ((int)(ENUM)) == LIBXSMM_DATATYPE_I8 ? 1 : ( \ 68 0/*invalid*/))))))) 69 70 /* Get input or output precision */ 71 #define LIBXSMM_GETENUM_INP(SRC) ((SRC) & 0x0F) 72 #define LIBXSMM_GETENUM_OUT(SRC) (0 == ((SRC) >> 4) ? LIBXSMM_GETENUM_INP(SRC) : ((SRC) >> 4)) 73 /* Get/Set input and output precision */ 74 #define LIBXSMM_GETENUM(INP, OUT) (((INP) == (OUT)) ? (INP) : ((INP) | ((OUT) << 4))) 75 #define LIBXSMM_SETENUM(DST, INP, OUT) DST = LIBXSMM_GETENUM(INP, OUT) 76 77 /* Construct an enumerator (libxsmm_datatype) from a built-in type (float, double, etc.). */ 78 #define LIBXSMM_DATATYPE(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_DATATYPE_, LIBXSMM_TYPESYMBOL(TYPE)) 79 /* Construct a type-id from built-in input/output types (float, double, etc.). */ 80 #define LIBXSMM_DATATYPE2(ITYPE, OTYPE) LIBXSMM_GETENUM(LIBXSMM_DATATYPE(ITYPE), LIBXSMM_DATATYPE(OTYPE)) 81 82 /* Construct an enumerator (libxsmm_gemm_precision) from a built-in type (float, double, etc.). 
*/ 83 #define LIBXSMM_GEMM_PRECISION(TYPE) LIBXSMM_CONCATENATE(LIBXSMM_GEMM_PRECISION_, LIBXSMM_TYPESYMBOL(TYPE)) 84 /* Construct GEMM-precision from built-in input/output types (float, double, etc.). */ 85 #define LIBXSMM_GEMM_PRECISION2(ITYPE, OTYPE) (libxsmm_gemm_precision)LIBXSMM_GETENUM( \ 86 LIBXSMM_GEMM_PRECISION(ITYPE), LIBXSMM_GEMM_PRECISION(OTYPE)) 87 88 /** Maximum size available to store a descriptor/blob (GEMM, MCOPY, TRANS, TRSM, TRMM). */ 89 #if !defined(LIBXSMM_DESCRIPTOR_MAXSIZE) 90 # define LIBXSMM_DESCRIPTOR_MAXSIZE 64 91 #endif 92 /** Size of the descriptor considered as unique signature. */ 93 #if !defined(LIBXSMM_DESCRIPTOR_SIGSIZE) 94 # define LIBXSMM_DESCRIPTOR_SIGSIZE LIBXSMM_DESCRIPTOR_MAXSIZE 95 #endif 96 97 98 /* Support for Bfloat16 */ 99 typedef unsigned short libxsmm_bfloat16; 100 101 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_bfloat16_hp { 102 libxsmm_bfloat16 i[2]; 103 float f; 104 } libxsmm_bfloat16_hp; 105 106 #if defined(__cplusplus) 107 namespace tensorflow { struct bfloat16; } 108 #endif /*__cplusplus*/ 109 110 /** Integer type for LAPACK/BLAS (LP64: 32-bit, and ILP64: 64-bit). */ 111 typedef LIBXSMM_BLASINT libxsmm_blasint; 112 113 /** Type representing sufficient storage space for a GEMM handle. */ 114 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_gemm_blob { char data[128]; } libxsmm_gemm_blob; 115 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_gemm_handle libxsmm_gemm_handle; 116 117 /** Type representing sufficient storage space for descriptors (GEMM, TCOPY, MCOPY). */ 118 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_descriptor_blob { 119 char data[LIBXSMM_DESCRIPTOR_MAXSIZE]; 120 } libxsmm_descriptor_blob; 121 122 /** Structure storing arguments of GEMM-like routines. */ 123 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_gemm_descriptor libxsmm_gemm_descriptor; 124 /** Structure storing arguments of the matrix-copy routine. 
*/ 125 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_mcopy_descriptor libxsmm_mcopy_descriptor; 126 /** Structure storing arguments of the matrix-eltw routine. */ 127 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_descriptor libxsmm_meltw_descriptor; 128 /** Structure storing arguments of the transpose routine. */ 129 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_trans_descriptor libxsmm_trans_descriptor; 130 /** Structure storing arguments of packed TRSM. */ 131 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_trsm_descriptor libxsmm_trsm_descriptor; 132 /** Structure storing arguments of packed TRMM. */ 133 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_trmm_descriptor libxsmm_trmm_descriptor; 134 /** Structure storing arguments of packed GETRF. */ 135 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_getrf_descriptor libxsmm_getrf_descriptor; 136 /** Structure storing arguments of packed GEMM. */ 137 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_pgemm_descriptor libxsmm_pgemm_descriptor; 138 139 /** Enumerates element/data types. */ 140 typedef enum libxsmm_datatype { 141 LIBXSMM_DATATYPE_F64, 142 LIBXSMM_DATATYPE_F32, 143 LIBXSMM_DATATYPE_BF16, 144 LIBXSMM_DATATYPE_I64, 145 LIBXSMM_DATATYPE_I32, 146 LIBXSMM_DATATYPE_I16, 147 LIBXSMM_DATATYPE_I8, 148 LIBXSMM_DATATYPE_UNSUPPORTED 149 } libxsmm_datatype; 150 151 /** Denotes the precision/data type of GEMM. 
*/ 152 typedef enum libxsmm_gemm_precision { 153 LIBXSMM_GEMM_PRECISION_F64 = LIBXSMM_DATATYPE_F64, 154 LIBXSMM_GEMM_PRECISION_F32 = LIBXSMM_DATATYPE_F32, 155 LIBXSMM_GEMM_PRECISION_BF16 = LIBXSMM_DATATYPE_BF16, 156 LIBXSMM_GEMM_PRECISION_I32 = LIBXSMM_DATATYPE_I32, 157 LIBXSMM_GEMM_PRECISION_I16 = LIBXSMM_DATATYPE_I16, 158 LIBXSMM_GEMM_PRECISION_I8 = LIBXSMM_DATATYPE_I8 159 } libxsmm_gemm_precision; 160 161 typedef enum libxsmm_meltw_operation { 162 LIBXSMM_MELTW_OPERATION_NONE = 0, 163 LIBXSMM_MELTW_OPERATION_COPY = 1, 164 LIBXSMM_MELTW_OPERATION_ZERO = 2, 165 LIBXSMM_MELTW_OPERATION_ADD = 3, 166 LIBXSMM_MELTW_OPERATION_MUL = 4, 167 LIBXSMM_MELTW_OPERATION_RELU = 5, 168 LIBXSMM_MELTW_OPERATION_CVTFP32BF16 = 6, 169 LIBXSMM_MELTW_OPERATION_REDUCE = 7, 170 LIBXSMM_MELTW_OPERATION_SCALE = 8, 171 LIBXSMM_MELTW_OPERATION_CVTFP32BF16_ACT = 9, 172 LIBXSMM_MELTW_OPERATION_ACT_CVTFP32BF16 = 10, 173 LIBXSMM_MELTW_OPERATION_COLBIAS_ACT = 11 174 } libxsmm_meltw_operation; 175 176 typedef enum libxsmm_meltw_null_flags { 177 LIBXSMM_MELTW_FLAG_NONE = 0 178 } libxsmm_meltw_null_flags; 179 180 typedef enum libxsmm_meltw_redu_flags { 181 LIBXSMM_MELTW_FLAG_REDUCE_NONE = 0, 182 LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD = 1, 183 LIBXSMM_MELTW_FLAG_REDUCE_OP_MAX = 2, 184 LIBXSMM_MELTW_FLAG_REDUCE_OP_MUL = 4, 185 LIBXSMM_MELTW_FLAG_REDUCE_ROWS = 8, 186 LIBXSMM_MELTW_FLAG_REDUCE_COLS = 16, 187 LIBXSMM_MELTW_FLAG_REDUCE_ELTS = 32, 188 LIBXSMM_MELTW_FLAG_REDUCE_ELTS_SQUARED = 64, 189 LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_ROWS = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_ROWS, 190 LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD_COLS = LIBXSMM_MELTW_FLAG_REDUCE_OP_ADD | LIBXSMM_MELTW_FLAG_REDUCE_COLS 191 } libxsmm_meltw_redu_flags; 192 193 typedef enum libxsmm_meltw_scal_flags { 194 LIBXSMM_MELTW_FLAG_SCALE_NONE = 0, 195 LIBXSMM_MELTW_FLAG_SCALE_MULT = 1, 196 LIBXSMM_MELTW_FLAG_SCALE_SHIFT = 2, 197 LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS = 4, 198 LIBXSMM_MELTW_FLAG_SCALE_ROWS = 8, 199 
LIBXSMM_MELTW_FLAG_SCALE_COLS = 16, 200 LIBXSMM_MELTW_FLAG_SCALE_MULT_ROWS = LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_ROWS, 201 LIBXSMM_MELTW_FLAG_SCALE_SHIFT_ROWS = LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_ROWS, 202 LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS_ROWS = LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_ROWS, 203 LIBXSMM_MELTW_FLAG_SCALE_MULT_SHIFT_ROWS = LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_ROWS, 204 LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS_SHIFT_ROWS = LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_ROWS, 205 LIBXSMM_MELTW_FLAG_SCALE_MULT_ADD_BIAS_ROWS = LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_ROWS, 206 LIBXSMM_MELTW_FLAG_SCALE_MULT_SHIFT_ADD_BIAS_ROWS = LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_ROWS, 207 LIBXSMM_MELTW_FLAG_SCALE_MULT_COLS = LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_COLS, 208 LIBXSMM_MELTW_FLAG_SCALE_SHIFT_COLS = LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_COLS, 209 LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS_COLS = LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_COLS, 210 LIBXSMM_MELTW_FLAG_SCALE_MULT_SHIFT_COLS = LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_COLS, 211 LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS_SHIFT_COLS = LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_COLS, 212 LIBXSMM_MELTW_FLAG_SCALE_MULT_ADD_BIAS_COLS = LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_COLS, 213 LIBXSMM_MELTW_FLAG_SCALE_MULT_SHIFT_ADD_BIAS_COLS = LIBXSMM_MELTW_FLAG_SCALE_MULT | LIBXSMM_MELTW_FLAG_SCALE_SHIFT | LIBXSMM_MELTW_FLAG_SCALE_ADD_BIAS | LIBXSMM_MELTW_FLAG_SCALE_COLS 214 } libxsmm_meltw_scal_flags; 215 216 typedef enum 
libxsmm_meltw_cvta_flags { 217 LIBXSMM_MELTW_FLAG_CVTA_NONE = 0, 218 LIBXSMM_MELTW_FLAG_CVTA_FUSE_RELU = 1, 219 LIBXSMM_MELTW_FLAG_CVTA_FUSE_TANH = 2, 220 LIBXSMM_MELTW_FLAG_CVTA_FUSE_SIGM = 4 221 } libxsmm_meltw_cvta_flags; 222 223 typedef enum libxsmm_meltw_acvt_flags { 224 LIBXSMM_MELTW_FLAG_ACVT_NONE = 0, 225 LIBXSMM_MELTW_FLAG_ACVT_FUSE_TANH = 1, 226 LIBXSMM_MELTW_FLAG_ACVT_FUSE_SIGM = 2 227 } libxsmm_meltw_acvt_flags; 228 229 typedef enum libxsmm_meltw_cbiasact_flags { 230 LIBXSMM_MELTW_FLAG_CBIASACT_NONE = 0, 231 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS = 1, 232 LIBXSMM_MELTW_FLAG_CBIASACT_ACT_RELU = 2, 233 LIBXSMM_MELTW_FLAG_CBIASACT_ACT_TANH = 4, 234 LIBXSMM_MELTW_FLAG_CBIASACT_ACT_SIGM = 8, 235 LIBXSMM_MELTW_FLAG_CBIASACT_ACT_GELU = 16, 236 LIBXSMM_MELTW_FLAG_CBIASACT_OVERWRITE_C = 32, 237 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_RELU = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_RELU, 238 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_TANH = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_TANH, 239 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_SIGM = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_SIGM, 240 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_GELU = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_GELU, 241 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_RELU_OVERWRITE_C = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_RELU | LIBXSMM_MELTW_FLAG_CBIASACT_OVERWRITE_C, 242 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_TANH_OVERWRITE_C = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_TANH | LIBXSMM_MELTW_FLAG_CBIASACT_OVERWRITE_C, 243 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_SIGM_OVERWRITE_C = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_SIGM | LIBXSMM_MELTW_FLAG_CBIASACT_OVERWRITE_C, 244 LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS_ACT_GELU_OVERWRITE_C = LIBXSMM_MELTW_FLAG_CBIASACT_COLBIAS | LIBXSMM_MELTW_FLAG_CBIASACT_ACT_GELU | 
LIBXSMM_MELTW_FLAG_CBIASACT_OVERWRITE_C 245 } libxsmm_meltw_cbiasact_flags; 246 247 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_xmelt_flags { 248 libxsmm_meltw_null_flags elt_null; 249 libxsmm_meltw_redu_flags elt_redu; 250 libxsmm_meltw_scal_flags elt_scal; 251 libxsmm_meltw_cvta_flags elt_cvta; 252 libxsmm_meltw_acvt_flags elt_acvt; 253 libxsmm_meltw_cbiasact_flags elt_cbiasact; 254 } libxsmm_xmelt_flags; 255 256 /** Flag enumeration which can be binary ORed. */ 257 typedef enum libxsmm_gemm_flags { 258 LIBXSMM_GEMM_FLAG_NONE = 0, 259 /** Transpose matrix A. */ 260 LIBXSMM_GEMM_FLAG_TRANS_A = 1, 261 /** Transpose matrix B. */ 262 LIBXSMM_GEMM_FLAG_TRANS_B = 2, 263 /** Transpose matrix A and B. */ 264 LIBXSMM_GEMM_FLAG_TRANS_AB = LIBXSMM_GEMM_FLAG_TRANS_A | LIBXSMM_GEMM_FLAG_TRANS_B, 265 #if 0 266 /** Alpha=0|1 */ 267 LIBXSMM_GEMM_FLAG_ALPHA_0 = 4, 268 /** Alpha=neg|pos */ 269 LIBXSMM_GEMM_FLAG_ALPHA_S = 8, 270 #endif 271 /** Beta=0|1 */ 272 LIBXSMM_GEMM_FLAG_BETA_0 = 16, 273 #if 0 274 /** Beta=neg|pos */ 275 LIBXSMM_GEMM_FLAG_BETA_S = 32, 276 #endif 277 /** Generate aligned load instructions. */ 278 LIBXSMM_GEMM_FLAG_ALIGN_A = 64, 279 /** Aligned load/store instructions. */ 280 LIBXSMM_GEMM_FLAG_ALIGN_C = 128, 281 /** Batch-reduce Ai * Bi. */ 282 LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS = 256, 283 /** Batch-reduce Ai * Bi. */ 284 LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET = 512, 285 /** Batch-reduce Ai * Bi. 
*/ 286 LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE = 1024, 287 /** Aligned C matrix, but using NTS Hint when storing */ 288 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT = 2176, 289 /* in case of integer GEMM, if A is unsigned */ 290 LIBXSMM_GEMM_FLAG_A_UNSIGNED = 4096, 291 /* in case of integer GEMM, if B is unsigned */ 292 LIBXSMM_GEMM_FLAG_B_UNSIGNED = 8192, 293 /* in case of integer GEMM, if C is unsigned */ 294 LIBXSMM_GEMM_FLAG_C_UNSIGNED = 16384, 295 /* in case of integer GEMM, if A and B are unsigned */ 296 LIBXSMM_GEMM_FLAG_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_A_UNSIGNED | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 297 /* for low precision we also require up-front packed formats "VNNI" for best performance, this flag indicates A */ 298 LIBXSMM_GEMM_FLAG_VNNI_A = 32768, 299 /* for low precision we also require up-front packed formats "VNNI" for best performance, this flag indicates B */ 300 LIBXSMM_GEMM_FLAG_VNNI_B = 65536, 301 /* combined types */ 302 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0 = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, 303 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, 304 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS, 305 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, 306 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET, 307 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT, 308 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE, 309 
LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, 310 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, 311 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_A_UNSIGNED, 312 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, 313 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_A_UNSIGNED, 314 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_A_UNSIGNED, 315 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE_A_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_A_UNSIGNED, 316 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 317 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 318 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 319 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | 
LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 320 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 321 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 322 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE_B_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_B_UNSIGNED, 323 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, 324 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_ADDRESS_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, 325 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_ADDRESS_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_ADDRESS | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, 326 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_OFFSET_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, 327 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_OFFSET_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_OFFSET | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, 328 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BATCH_REDUCE_STRIDE_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_AB_UNSIGNED, 329 LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT_BETA_0_BATCH_REDUCE_STRIDE_AB_UNSIGNED = LIBXSMM_GEMM_FLAG_BETA_0 | LIBXSMM_GEMM_FLAG_ALIGN_C_NTS_HINT | LIBXSMM_GEMM_FLAG_BATCH_REDUCE_STRIDE | 
LIBXSMM_GEMM_FLAG_AB_UNSIGNED, 330 /** Marker flag; do not use. */ 331 LIBXSMM_GEMM_FLAG_INVALID = 131072 332 } libxsmm_gemm_flags; 333 334 /** Flag enumeration which can be binary ORed. */ 335 typedef enum libxsmm_gemm_handle_flags { 336 LIBXSMM_GEMM_HANDLE_FLAG_AUTO = 0, 337 LIBXSMM_GEMM_HANDLE_FLAG_COPY_A = 1, 338 LIBXSMM_GEMM_HANDLE_FLAG_COPY_B = 2, 339 LIBXSMM_GEMM_HANDLE_FLAG_COPY_C = 4 340 } libxsmm_gemm_handle_flags; 341 342 /** Auto-batch flags (can be ORed) applicable to mmbatch_begin/mmbatch_end. */ 343 typedef enum libxsmm_mmbatch_flags { 344 /** Handle recorded batch unsynchronized-parallel. */ 345 LIBXSMM_MMBATCH_FLAG_DEFAULT = LIBXSMM_GEMM_FLAG_INVALID * 0, 346 /** Synchronize among C matrices. */ 347 LIBXSMM_MMBATCH_FLAG_SYNCHRONIZED = LIBXSMM_GEMM_FLAG_INVALID * 1, 348 /** Handle recorded batch sequentially. */ 349 LIBXSMM_MMBATCH_FLAG_SEQUENTIAL = LIBXSMM_GEMM_FLAG_INVALID * 2, 350 /** Only record a statistic of potential SMMs. */ 351 LIBXSMM_MMBATCH_FLAG_STATISTIC = LIBXSMM_GEMM_FLAG_INVALID * 4 352 } libxsmm_mmbatch_flags; 353 354 /** Enumeration of the available prefetch strategies. */ 355 typedef enum libxsmm_gemm_prefetch_type { 356 /** No prefetching and no prefetch fn. signature. */ 357 LIBXSMM_GEMM_PREFETCH_NONE = LIBXSMM_PREFETCH_NONE, 358 /** Only function prefetch signature. */ 359 LIBXSMM_GEMM_PREFETCH_SIGONLY = LIBXSMM_PREFETCH_SIGONLY, 360 /** Prefetch PA using accesses to A. */ 361 LIBXSMM_GEMM_PREFETCH_AL2 = 2, 362 /** Prefetch PA (aggressive). */ 363 LIBXSMM_GEMM_PREFETCH_BL2_VIA_C = 4, 364 /** Prefetch A ahead. */ 365 LIBXSMM_GEMM_PREFETCH_AL2_AHEAD = 8, 366 LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C = LIBXSMM_GEMM_PREFETCH_BL2_VIA_C | LIBXSMM_GEMM_PREFETCH_AL2, 367 LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C_AHEAD = LIBXSMM_GEMM_PREFETCH_BL2_VIA_C | LIBXSMM_GEMM_PREFETCH_AL2_AHEAD, 368 /** Backward compatibility: AL2CL2BL2_VIA_C is an alias for AL2BL2_VIA_C (Eigen library). 
*/ 369 LIBXSMM_PREFETCH_AL2CL2BL2_VIA_C = LIBXSMM_GEMM_PREFETCH_AL2BL2_VIA_C, 370 /** Current B into L1. */ 371 LIBXSMM_GEMM_PREFETCH_BL1 = 16 372 } libxsmm_gemm_prefetch_type; 373 374 /** Flag enumeration which can be binary ORed. */ 375 typedef enum libxsmm_matcopy_flags { 376 LIBXSMM_MATCOPY_FLAG_DEFAULT = 0, 377 /** If set, then use zero matrix as source */ 378 LIBXSMM_MATCOPY_FLAG_ZERO_SOURCE = 1 379 } libxsmm_matcopy_flags; 380 381 /** Determines the kernel kind. */ 382 typedef enum libxsmm_kernel_kind { 383 /** Matrix multiplication kernel */ 384 LIBXSMM_KERNEL_KIND_MATMUL = 0, 385 /** Matcopy kernel kind */ 386 LIBXSMM_KERNEL_KIND_MCOPY = 1, 387 /** Mateltw kernel kind */ 388 LIBXSMM_KERNEL_KIND_MELTW = 2, 389 /** Transpose kernel kind */ 390 LIBXSMM_KERNEL_KIND_TRANS = 3, 391 /** GEMM/packed kernel kind */ 392 LIBXSMM_KERNEL_KIND_PGEMM = 4, 393 /** GEMM/packed kernel kind */ 394 LIBXSMM_KERNEL_KIND_GETRF = 5, 395 /** TRMM kernel kind */ 396 LIBXSMM_KERNEL_KIND_TRMM = 6, 397 /** TRSM kernel kind */ 398 LIBXSMM_KERNEL_KIND_TRSM = 7, 399 /** User-defined kernels */ 400 LIBXSMM_KERNEL_KIND_USER = 8, 401 /** Not a JIT kernel */ 402 LIBXSMM_KERNEL_UNREGISTERED = 9 403 } libxsmm_kernel_kind; 404 405 typedef enum libxsmm_dnn_tensor_format { 406 /* use LIBXSMM internal format, we need to copy data into that */ 407 LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM = 1, 408 /* use NHWC format internally, this allows no-copy operations */ 409 LIBXSMM_DNN_TENSOR_FORMAT_NHWC = 2, 410 /* use NCHW format internally, this will include shadow copies, not preferred */ 411 LIBXSMM_DNN_TENSOR_FORMAT_NCHW = 4, 412 /* use RSCK format internally, this allows no-copy operations */ 413 LIBXSMM_DNN_TENSOR_FORMAT_RSCK = 8, 414 /* use KCRS format internally, this will include shadow copies, not preferred */ 415 LIBXSMM_DNN_TENSOR_FORMAT_KCRS = 16, 416 LIBXSMM_DNN_TENSOR_FORMAT_CK = 32, 417 LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED = 64, 418 LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED = 128, 419 
LIBXSMM_DNN_TENSOR_FORMAT_NC = 256 420 } libxsmm_dnn_tensor_format; 421 422 /** Denotes the element/pixel type of an image/channel. */ 423 typedef enum libxsmm_dnn_datatype { 424 LIBXSMM_DNN_DATATYPE_F64 = LIBXSMM_DATATYPE_F64, 425 LIBXSMM_DNN_DATATYPE_F32 = LIBXSMM_DATATYPE_F32, 426 LIBXSMM_DNN_DATATYPE_BF16 = LIBXSMM_DATATYPE_BF16, 427 LIBXSMM_DNN_DATATYPE_I32 = LIBXSMM_DATATYPE_I32, 428 LIBXSMM_DNN_DATATYPE_I16 = LIBXSMM_DATATYPE_I16, 429 LIBXSMM_DNN_DATATYPE_I8 = LIBXSMM_DATATYPE_I8 430 } libxsmm_dnn_datatype; 431 432 typedef enum libxsmm_dnn_conv_option { 433 /* we get default settings */ 434 LIBXSMM_DNN_CONV_OPTION_NONE = 0, 435 /* overwrite results buffer (set it to zero before running the operations) */ 436 LIBXSMM_DNN_CONV_OPTION_OVERWRITE = 1, 437 /* external filter transpose to bwd convolutions */ 438 LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE = 2, 439 /* compound types */ 440 LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE_OVERWRITE = LIBXSMM_DNN_CONV_OPTION_OVERWRITE | LIBXSMM_DNN_CONV_OPTION_BWD_NO_FILTER_TRANSPOSE 441 } libxsmm_dnn_conv_option; 442 443 typedef enum libxsmm_dnn_fusedbatchnorm_fuse_order { 444 /* the fuse order is: 1. BN, 2. element-wise 3. RELU */ 445 LIBXSMM_DNN_FUSEDBN_ORDER_BN_ELTWISE_RELU = 0 446 } libxsmm_dnn_fusedbatchnorm_fuse_order; 447 448 typedef enum libxsmm_dnn_fusedbatchnorm_fuse_op { 449 /* the fuse order is: 1. BN, 2. element-wise 3. 
RELU */ 450 LIBXSMM_DNN_FUSEDBN_OPS_BN = 1, 451 LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE = 2, 452 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS = 4, 453 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED = 8, 454 LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE = 16, 455 LIBXSMM_DNN_FUSEDBN_OPS_RELU = 32, 456 LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK = 64, 457 LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 458 LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 459 LIBXSMM_DNN_FUSEDBN_OPS_BN_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, 460 LIBXSMM_DNN_FUSEDBN_OPS_BN_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 461 LIBXSMM_DNN_FUSEDBN_OPS_BN_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 462 LIBXSMM_DNN_FUSEDBN_OPS_BN_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 463 LIBXSMM_DNN_FUSEDBN_OPS_BN_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BN | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 464 LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, 465 LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 466 LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 467 LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 468 LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSCALE | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 469 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, 470 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_RELU = 
LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 471 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 472 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 473 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 474 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_ELTWISE = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE, 475 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 476 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK, 477 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_ELTWISE_RELU = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU, 478 LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDBN_OPS_BNSTATS_NORED | LIBXSMM_DNN_FUSEDBN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDBN_OPS_RELU_WITH_MASK 479 } libxsmm_dnn_fusedbatchnorm_fuse_op; 480 481 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedbatchnorm_desc { 482 int partN; /* number of images in mini-batch, used for all elementwise computations */ 483 int fullN; /* number of images in mini-batch, used for statistics computations */ 484 int C; /* number of input feature maps */ 485 int H; /* height of input image */ 486 int W; /* width of input image */ 487 int u; /* vertical stride */ 488 int v; /* horizontal stride */ 489 int pad_h_in; /* height of physical zero-padding in input buffer */ 490 int pad_w_in; /* width of physical zero-padding in input buffer */ 491 int pad_h_out; /* height of physical zero-padding in output buffer */ 492 int pad_w_out; /* width of physical 
zero-padding in output buffer */ 493 int threads; /* number of threads used */ 494 libxsmm_dnn_datatype datatype_in; /* datatype used for all input related buffers */ 495 libxsmm_dnn_datatype datatype_out; /* datatype used for all output related buffers */ 496 libxsmm_dnn_datatype datatype_stats; /* datatype used for all stats related buffers */ 497 libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ 498 libxsmm_dnn_fusedbatchnorm_fuse_order fuse_order; /* additional options */ 499 libxsmm_dnn_fusedbatchnorm_fuse_op fuse_ops; /* used ops into convolutions */ 500 } libxsmm_dnn_fusedbatchnorm_desc; 501 502 typedef enum libxsmm_dnn_fusedgroupnorm_fuse_order { 503 /* the fuse order is: 1. BN, 2. element-wise 3. RELU */ 504 LIBXSMM_DNN_FUSEDGN_ORDER_GN_ELTWISE_RELU = 0 505 } libxsmm_dnn_fusedgroupnorm_fuse_order; 506 507 typedef enum libxsmm_dnn_fusedgroupnorm_fuse_op { 508 /* the fuse order is: 1. GN, 2. element-wise 3. RELU */ 509 LIBXSMM_DNN_FUSEDGN_OPS_GN = 1, 510 LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE = 2, 511 LIBXSMM_DNN_FUSEDGN_OPS_RELU = 4, 512 LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK = 8, 513 LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU = LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU, 514 LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK, 515 LIBXSMM_DNN_FUSEDGN_OPS_GN_ELTWISE = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE, 516 LIBXSMM_DNN_FUSEDGN_OPS_GN_RELU = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_RELU, 517 LIBXSMM_DNN_FUSEDGN_OPS_GN_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK, 518 LIBXSMM_DNN_FUSEDGN_OPS_GN_ELTWISE_RELU = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU, 519 LIBXSMM_DNN_FUSEDGN_OPS_GN_ELTWISE_RELU_WITH_MASK = LIBXSMM_DNN_FUSEDGN_OPS_GN | LIBXSMM_DNN_FUSEDGN_OPS_ELTWISE | LIBXSMM_DNN_FUSEDGN_OPS_RELU_WITH_MASK 520 } 
libxsmm_dnn_fusedgroupnorm_fuse_op; 521 522 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedgroupnorm_desc { 523 int N; /* number of images in mini-batch */ 524 int G; /* groups of channels to norm */ 525 int C; /* number of input feature maps */ 526 int H; /* height of input image */ 527 int W; /* width of input image */ 528 int u; /* vertical stride */ 529 int v; /* horizontal stride */ 530 int pad_h_in; /* height of physical zero-padding in input buffer */ 531 int pad_w_in; /* width of physical zero-padding in input buffer */ 532 int pad_h_out; /* height of physical zero-padding in output buffer */ 533 int pad_w_out; /* width of physical zero-padding in output buffer */ 534 int threads; /* number of threads used */ 535 libxsmm_dnn_datatype datatype_in; /* datatype used for all input related buffers */ 536 libxsmm_dnn_datatype datatype_out; /* datatype used for all output related buffers */ 537 libxsmm_dnn_datatype datatype_stats; /* datatype used for all stats related buffers */ 538 libxsmm_dnn_tensor_format buffer_format; /* format which is for activation buffers */ 539 libxsmm_dnn_fusedgroupnorm_fuse_order fuse_order; /* additional options */ 540 libxsmm_dnn_fusedgroupnorm_fuse_op fuse_ops; /* used ops into convolutions */ 541 } libxsmm_dnn_fusedgroupnorm_desc; 542 543 /** argument struct for matrix-eltwise: copy */ 544 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_copy_param { 545 const void* in_ptr; /* input pointer */ 546 void* out_ptr; /* output pointer */ 547 } libxsmm_meltw_copy_param; 548 549 /** argument struct for matrix-eltwise: zero */ 550 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_zero_param { 551 const void* in_ptr; /* input pointer */ 552 void* out_ptr; /* output pointer */ 553 } libxsmm_meltw_zero_param; 554 555 /** argument struct for matrix-eltwise: add */ 556 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_add_param { 557 const void* in_ptr; /* input pointer 
*/ 558 void* out_ptr; /* output pointer */ 559 } libxsmm_meltw_add_param; 560 561 /** argument struct for matrix-eltwise: mul */ 562 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_mul_param { 563 const void* in_ptr; /* input pointer */ 564 void* out_ptr; /* output pointer */ 565 } libxsmm_meltw_mul_param; 566 567 /** argument struct for matrix-eltwise: relu */ 568 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_relu_param { 569 const void* in_ptr; /* input pointer */ 570 void* mask_ptr; /* pointer to load/store ReLU mask */ 571 void* out_ptr; /* output pointer */ 572 } libxsmm_meltw_relu_param; 573 574 /** argument struct for matrix-eltwise: cvtfp32bf16 */ 575 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_cvtfp32bf16_param { 576 const void* in_ptr; /* input pointer */ 577 void* out_ptr; /* output pointer */ 578 } libxsmm_meltw_cvtfp32bf16_param; 579 580 /** argument struct for matrix-eltwise: cvtfp32bf16_act */ 581 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_cvtfp32bf16_act_param { 582 const void* in_ptr; /* input pointer */ 583 void* out_ptr; /* output pointer */ 584 void* actstore_ptr; /* output pointer for activation if it is fused into the convert */ 585 } libxsmm_meltw_cvtfp32bf16_act_param; 586 587 /** argument struct for matrix-eltwise: act_cvtfp32bf16 */ 588 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_act_cvtfp32bf16_param { 589 const void* in_ptr; /* input pointer */ 590 void* out_ptr; /* output pointer */ 591 void* actstore_ptr; /* output pointer for activation if it is fused into the convert */ 592 } libxsmm_meltw_act_cvtfp32bf16_param; 593 594 /** argument struct for matrix-eltwise: reduce */ 595 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_reduce_param { 596 const void* in_ptr; /* input pointer */ 597 void* out_ptr_0; /* output pointer */ 598 void* out_ptr_1; /* output pointer */ 599 } libxsmm_meltw_reduce_param; 600 601 /** 
argument struct for matrix-eltwise: scale */ 602 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_scale_param { 603 const void* in_ptr; /* input pointer */ 604 const void* shift_vals_ptr; /* pointer to shift values array */ 605 const void* scale_vals_ptr; /* pointer to scale values array */ 606 const void* bias_vals_ptr; /* pointer to bias values array*/ 607 void* out_ptr; /* output pointer */ 608 } libxsmm_meltw_scale_param; 609 610 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_cbiasact_param { 611 const void* in_ptr; /* input pointer */ 612 const void* bias_ptr; /* col-bias pointer */ 613 void* mask_ptr; /* pointer to load/store ReLU mask */ 614 void* out_ptr; /* output pointer */ 615 } libxsmm_meltw_cbiasact_param; 616 617 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltw_cbiasact_gemm_param { 618 const void* bias_ptr; /* optional, col-bias pointer */ 619 void* out_ptr; /* optional, pointer to output after eltwise (contains mask in case of ReLU); */ 620 /* Need for some activation functions, assumed to have the same shape as C matrix, */ 621 /* may not be set when OVERWRITE_C option is chosen */ 622 /* If OVERWRITE_C is false: out_ptr contains the post-act output, C has the pre-act output */ 623 /* If OVERWRITE_C is true: C contains post-act output, out_ptr contains the ReLU mask (only when act was ReLU) for other act unused */ 624 } libxsmm_meltw_cbiasact_gemm_param; 625 626 /** Specialized function for matrix-eltw (weak-typed). 
*/ 627 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_copy)(const libxsmm_meltw_copy_param* in_struct); 628 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_zero)(const libxsmm_meltw_zero_param* in_struct); 629 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_add)(const libxsmm_meltw_add_param* in_struct); 630 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_mul)(const libxsmm_meltw_mul_param* in_struct); 631 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_relu)(const libxsmm_meltw_relu_param* in_struct); 632 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_cvtfp32bf16)(const libxsmm_meltw_cvtfp32bf16_param* in_struct); 633 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_reduce)(const libxsmm_meltw_reduce_param* in_struct); 634 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_scale)(const libxsmm_meltw_scale_param* in_struct); 635 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_cvtfp32bf16_act)(const libxsmm_meltw_cvtfp32bf16_act_param* in_struct); 636 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_meltwfunction_act_cvtfp32bf16)(const libxsmm_meltw_act_cvtfp32bf16_param* in_struct); 637 638 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_xmeltwfunction { 639 void (*xmeltw)(const void* in_struct); 640 libxsmm_meltwfunction_copy meltw_copy; libxsmm_meltwfunction_zero meltw_zero; 641 libxsmm_meltwfunction_add meltw_add; libxsmm_meltwfunction_mul meltw_mul; 642 libxsmm_meltwfunction_relu meltw_relu; libxsmm_meltwfunction_cvtfp32bf16 meltw_cvtfp32bf16; 643 libxsmm_meltwfunction_reduce meltw_reduce; libxsmm_meltwfunction_scale meltw_scale; 644 libxsmm_meltwfunction_cvtfp32bf16_act meltw_cvtfp32bf16_act; 645 libxsmm_meltwfunction_act_cvtfp32bf16 meltw_act_cvtfp32bf16; 646 } libxsmm_xmeltwfunction; 647 648 /** 
Specialized function with fused alpha and beta arguments, and optional prefetch locations (double-precision). */ 649 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction)(const double* a, const double* b, double* c, ...); 650 /** Specialized function with fused alpha and beta arguments, and optional prefetch locations (single-precision). */ 651 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction)(const float* a, const float* b, float* c, ...); 652 /** Specialized function with fused alpha and beta arguments, and optional prefetch locations (bf16, fp32-accumulate). */ 653 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, float* c, ...); 654 /** Specialized function with fused alpha and beta arguments, and optional prefetch locations (bf16, fp32-accumulate). */ 655 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, libxsmm_bfloat16* c, ...); 656 /** Specialized function with fused alpha and beta arguments, and optional prefetch locations (low-precision). */ 657 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction)(const short* a, const short* b, int* c, ...); 658 /** Specialized function with fused alpha and beta arguments, and optional prefetch locations (int8, int32 accumulate). 
*/ 659 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction)(const char* a, const char* b, int* c, ...); 660 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction)(const unsigned char* a, const char* b, int* c, ...); 661 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction)(const char* a, const unsigned char* b, int* c, ...); 662 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction)(const unsigned char* a, const unsigned char* b, int* c, ...); 663 /** Specialized function with fused alpha and beta arguments, and optional prefetch locations (int8, int32 accumulate, int8 downconvert). */ 664 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction)(const char* a, const unsigned char* b, unsigned char* c, float* scf, ...); 665 666 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction_reducebatch_addr)(const double** a, const double** b, double* c, const unsigned long long* count, ...); 667 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction_reducebatch_addr)(const float** a, const float** b, float* c, const unsigned long long* count, ...); 668 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction_reducebatch_addr)(const libxsmm_bfloat16** a, const libxsmm_bfloat16** b, float* c, const unsigned long long* count, ...); 669 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction_reducebatch_addr)(const libxsmm_bfloat16** a, const libxsmm_bfloat16** b, libxsmm_bfloat16* c, const unsigned long long* count, ...); 670 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction_reducebatch_addr)(const short** a, const short** b, int* c, const unsigned long long* count, ...); 671 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction_reducebatch_addr)(const char** a, const char** b, int* c, const unsigned long long* count, ...); 672 LIBXSMM_EXTERN_C typedef 
LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction_reducebatch_addr)(const unsigned char** a, const char** b, int* c, const unsigned long long* count, ...); 673 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction_reducebatch_addr)(const char** a, const unsigned char** b, int* c, const unsigned long long* count, ...); 674 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction_reducebatch_addr)(const unsigned char** a, const unsigned char** b, int* c, const unsigned long long* count, ...); 675 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction_reducebatch_addr)(const char** a, const unsigned char** b, unsigned char* c, const unsigned long long* count, float* scf, ...); 676 677 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction_reducebatch_offs)(const double* a, const double* b, double* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 678 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction_reducebatch_offs)(const float* a, const float* b, float* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 679 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction_reducebatch_offs)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, float* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 680 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction_reducebatch_offs)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, libxsmm_bfloat16* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 681 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction_reducebatch_offs)(const short* a, const short* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const 
unsigned long long* b_offs, ...); 682 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction_reducebatch_offs)(const char* a, const char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 683 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction_reducebatch_offs)(const unsigned char* a, const char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 684 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction_reducebatch_offs)(const char* a, const unsigned char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 685 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction_reducebatch_offs)(const unsigned char* a, const unsigned char* b, int* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, ...); 686 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction_reducebatch_offs)(const char* a, const unsigned char* b, unsigned char* c, const unsigned long long* count, const unsigned long long* a_offs, const unsigned long long* b_offs, float* scf, ...); 687 688 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_dmmfunction_reducebatch_strd)(const double* a, const double* b, double* c, const unsigned long long* count, ...); 689 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_smmfunction_reducebatch_strd)(const float* a, const float* b, float* c, const unsigned long long* count, ...); 690 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bsmmfunction_reducebatch_strd)(const libxsmm_bfloat16* a, const libxsmm_bfloat16* b, float* c, const unsigned long long* count, ...); 691 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_bmmfunction_reducebatch_strd)(const libxsmm_bfloat16* a, 
const libxsmm_bfloat16* b, libxsmm_bfloat16* c, const unsigned long long* count, ...); 692 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_wimmfunction_reducebatch_strd)(const short* a, const short* b, int* c, const unsigned long long* count, ...); 693 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_ssbimmfunction_reducebatch_strd)(const char* a, const char* b, int* c, const unsigned long long* count, ...); 694 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_usbimmfunction_reducebatch_strd)(const unsigned char* a, const char* b, int* c, const unsigned long long* count, ...); 695 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_subimmfunction_reducebatch_strd)(const char* a, const unsigned char* b, int* c, const unsigned long long* count, ...); 696 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_uubimmfunction_reducebatch_strd)(const unsigned char* a, const unsigned char* b, int* c, const unsigned long long* count, ...); 697 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_sububmmfunction_reducebatch_strd)(const char* a, const unsigned char* b, unsigned char* c, const unsigned long long* count, float* scf, ...); 698 699 /** Function type which is either libxsmm_smmfunction or libxsmm_dmmfunction (weak-typed). 
*/
LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_xmmfunction {
  /** generic form: untyped a, b, and c */
  void (*xmm)(const void* a, const void* b, void* c, ...);
  /** generic form: untyped arrays of a/b pointers plus a pointer to count */
  void (*xbm)(const void** a, const void** b, void* c, const unsigned long long* count, ...);
  /* regular kernels */
  libxsmm_dmmfunction dmm;
  libxsmm_smmfunction smm;
  libxsmm_wimmfunction wimm;
  libxsmm_bsmmfunction bsmm;
  libxsmm_bmmfunction bmm;
  libxsmm_ssbimmfunction ssbimm;
  libxsmm_usbimmfunction usbimm;
  libxsmm_subimmfunction subimm;
  libxsmm_uubimmfunction uubimm;
  libxsmm_sububmmfunction sububmm;
  /* reducebatch_addr kernels */
  libxsmm_dmmfunction_reducebatch_addr dmra;
  libxsmm_smmfunction_reducebatch_addr smra;
  libxsmm_bsmmfunction_reducebatch_addr bsmra;
  libxsmm_bmmfunction_reducebatch_addr bmra;
  libxsmm_wimmfunction_reducebatch_addr wimra;
  libxsmm_ssbimmfunction_reducebatch_addr ssbimra;
  libxsmm_usbimmfunction_reducebatch_addr usbimra;
  libxsmm_subimmfunction_reducebatch_addr subimra;
  libxsmm_uubimmfunction_reducebatch_addr uubimra;
  libxsmm_sububmmfunction_reducebatch_addr sububmra;
  /* reducebatch_offs kernels */
  libxsmm_dmmfunction_reducebatch_offs dmro;
  libxsmm_smmfunction_reducebatch_offs smro;
  libxsmm_bsmmfunction_reducebatch_offs bsmro;
  libxsmm_bmmfunction_reducebatch_offs bmro;
  libxsmm_wimmfunction_reducebatch_offs wimro;
  libxsmm_ssbimmfunction_reducebatch_offs ssbimro;
  libxsmm_usbimmfunction_reducebatch_offs usbimro;
  libxsmm_subimmfunction_reducebatch_offs subimro;
  libxsmm_uubimmfunction_reducebatch_offs uubimro;
  libxsmm_sububmmfunction_reducebatch_offs sububmro;
  /* reducebatch_strd kernels */
  libxsmm_dmmfunction_reducebatch_strd dmrs;
  libxsmm_smmfunction_reducebatch_strd smrs;
  libxsmm_bsmmfunction_reducebatch_strd bsmrs;
  libxsmm_bmmfunction_reducebatch_strd bmrs;
  libxsmm_wimmfunction_reducebatch_strd wimrs;
  libxsmm_ssbimmfunction_reducebatch_strd ssbimrs;
  libxsmm_usbimmfunction_reducebatch_strd usbimrs;
  libxsmm_subimmfunction_reducebatch_strd subimrs;
  libxsmm_uubimmfunction_reducebatch_strd uubimrs;
  libxsmm_sububmmfunction_reducebatch_strd sububmrs;
} libxsmm_xmmfunction;

716 /** Specialized function for matrix-copy (weak-typed). */ 717 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_xmcopyfunction)( 718 const void* in, const unsigned int* ldi, void* out, const unsigned int* ldo, ...); 719 720 /** Specialized function for transpose (weak-typed). */ 721 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_xtransfunction)( 722 const void* in, const unsigned int* ldi, void* out, const unsigned int* ldo); 723 724 /** Specialized function for packed GEMM (weak-typed). */ 725 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_pgemm_xfunction)( 726 const void* a, const void* b, void* c); 727 728 /** Specialized function for packed GEMM (weak-typed). */ 729 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_getrf_xfunction)( 730 const void* a, const void* b, void* c); 731 732 /** Specialized function for TRMM (weak-typed). */ 733 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_trmm_xfunction)( 734 const void* a, const void* b, void* c); 735 736 /** Specialized function for TRSM (weak-typed). */ 737 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void (*libxsmm_trsm_xfunction)( 738 const void* a, const void* b, void* c); 739 740 /** Structure to receive information about GEMM-kernels (libxsmm_get_mmkernel_info). */ 741 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_mmkernel_info { 742 /** Input/output data-type */ 743 libxsmm_gemm_precision iprecision, oprecision; 744 /** Prefetch strategy. */ 745 libxsmm_gemm_prefetch_type prefetch; 746 /** Leading dimensions. */ 747 unsigned int lda, ldb, ldc; 748 /** Extents/shape. */ 749 unsigned int m, n, k; 750 /** Set of flags. */ 751 int flags; 752 } libxsmm_mmkernel_info; 753 754 /** Structure to receive information about transpose-kernels (libxsmm_get_transkernel_info). */ 755 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_transkernel_info { 756 /** LD, M, and N. 
*/ 757 unsigned int ldo, m, n; 758 /** Size of data element. */ 759 unsigned int typesize; 760 } libxsmm_transkernel_info; 761 762 /** Structure to receive information about matrix-copy kernels (libxsmm_get_mcopykernel_info). */ 763 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_mcopykernel_info { 764 /** LDx, M, and N. */ 765 unsigned int ldi, ldo, m, n; 766 /** Size of data element. */ 767 unsigned int typesize; 768 /** Boolean value. */ 769 int prefetch; 770 /** Set of flags. */ 771 int flags; 772 } libxsmm_mcopykernel_info; 773 774 /** Structure to receive information about matrix-eltw kernels (libxsmm_get_mcopykernel_info). */ 775 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_meltwkernel_info { 776 /** LDx, M, and N. */ 777 unsigned int ldi, ldo, m, n; 778 /** Size of data element. */ 779 unsigned int datatype; 780 /** Set of flags. */ 781 unsigned int flags; 782 /** Set of operation. */ 783 unsigned int operation; 784 } libxsmm_meltwkernel_info; 785 786 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_kernel_info { 787 libxsmm_kernel_kind kind; 788 /** Number of FLoating Point OperationS (FLOPS). */ 789 unsigned int nflops; 790 /** Code size (Bytes). */ 791 size_t code_size; 792 } libxsmm_kernel_info; 793 794 /** Structure to receive information about the code registry status (libxsmm_get_registry_info). */ 795 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_registry_info { 796 size_t capacity, size, nbytes, nstatic, ncache; 797 } libxsmm_registry_info; 798 799 #endif /*LIBXSMM_TYPEDEFS_H*/ 800 801