1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_MAIN_H
12 #define LIBXSMM_MAIN_H
13
14 #include <libxsmm.h>
/**
 * TF includes src/libxsmm_main.h and uses LIBXSMM's sync primitives
 * without including libxsmm_sync.h. However, libxsmm_sync.h shall remain
 * an explicit include, separate from including libxsmm.h.
 */
20 #include "libxsmm_sync.h"
21
/**
 * Capacities of the code registry and of the dispatch cache (number of entries).
 * Allow external definition to enable testing corner cases (exhausted registry space).
 * Both must be a power of two (POT). This is now checked at preprocessing time, so a bad
 * external definition is rejected early instead of misbehaving later.
 */
#if !defined(LIBXSMM_CAPACITY_REGISTRY) /* must be POT */
# define LIBXSMM_CAPACITY_REGISTRY 131072
#endif
#if (1 > (LIBXSMM_CAPACITY_REGISTRY)) || \
    (0 != ((LIBXSMM_CAPACITY_REGISTRY) & ((LIBXSMM_CAPACITY_REGISTRY) - 1)))
# error "LIBXSMM_CAPACITY_REGISTRY must be a positive power of two"
#endif
#if !defined(LIBXSMM_CAPACITY_CACHE) /* must be POT */
# define LIBXSMM_CAPACITY_CACHE 16
#endif
/* zero is not rejected here (NOTE(review): presumably used to disable the cache; confirm) */
#if (0 > (LIBXSMM_CAPACITY_CACHE)) || \
    (0 != ((LIBXSMM_CAPACITY_CACHE) & ((LIBXSMM_CAPACITY_CACHE) - 1)))
# error "LIBXSMM_CAPACITY_CACHE must be a power of two"
#endif
29
/* Minimum page size in bytes (overridable). */
#ifndef LIBXSMM_PAGE_MINSIZE
# define LIBXSMM_PAGE_MINSIZE 4096 /* 4 KB */
#endif

/* Upper bound of threads: a single thread unless LIBXSMM_SYNC is non-zero. */
#ifndef LIBXSMM_NTHREADS_MAX
# if (0 == LIBXSMM_SYNC)
#   define LIBXSMM_NTHREADS_MAX 1
# else
#   define LIBXSMM_NTHREADS_MAX 1024
# endif
#endif
/* code relies on LIBXSMM_NTHREADS_MAX or v/forks ("&& 1" acts as an on/off switch) */
#if !defined(LIBXSMM_NTHREADS_USE) && 1
# define LIBXSMM_NTHREADS_USE
#endif

/* Scratch memory: number of pools defaults to the thread bound; scale factor defaults to 1. */
#ifndef LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS
# define LIBXSMM_MALLOC_SCRATCH_MAX_NPOOLS LIBXSMM_NTHREADS_MAX
#endif
#ifndef LIBXSMM_MALLOC_SCRATCH_SCALE
# define LIBXSMM_MALLOC_SCRATCH_SCALE 1.0
#endif
#ifndef LIBXSMM_MALLOC_LIMIT
# define LIBXSMM_MALLOC_LIMIT (2U << 20) /* 2 MB */
#endif

/* Hook realloc/calloc as well ("&& 1" acts as an on/off switch). */
#if !defined(LIBXSMM_MALLOC_HOOK_REALLOC) && 1
# define LIBXSMM_MALLOC_HOOK_REALLOC
#endif
#if !defined(LIBXSMM_MALLOC_HOOK_CALLOC) && 1
# define LIBXSMM_MALLOC_HOOK_CALLOC
#endif
/* align even if interceptor is disabled at runtime */
#if !defined(LIBXSMM_MALLOC_ALIGN_ALL) && 1
# define LIBXSMM_MALLOC_ALIGN_ALL
#endif

/* Caller-ID (and its pointer form) used to tag allocations made internally. */
#ifndef LIBXSMM_MALLOC_INTERNAL_CALLER_ID
# define LIBXSMM_MALLOC_INTERNAL_CALLER_ID ((uintptr_t)LIBXSMM_UNLIMITED)
#endif
#ifndef LIBXSMM_MALLOC_INTERNAL_CALLER
# define LIBXSMM_MALLOC_INTERNAL_CALLER ((const void*)(LIBXSMM_MALLOC_INTERNAL_CALLER_ID))
#endif
70
71 #if !defined(LIBXSMM_INTERCEPT_DYNAMIC) && defined(LIBXSMM_BUILD) && \
72 (defined(__GNUC__) || defined(_CRAYC)) && !defined(_WIN32) && !defined(__CYGWIN__) && \
73 !(defined(__APPLE__) && defined(__MACH__) && LIBXSMM_VERSION2(6, 1) >= \
74 LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
75 # define LIBXSMM_INTERCEPT_DYNAMIC
76 #endif
77
78 #if !defined(LIBXSMM_MALLOC_HOOK_DYNAMIC) && defined(LIBXSMM_INTERCEPT_DYNAMIC) && \
79 defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC) && \
80 (!defined(_CRAYC) && !defined(__TRACE)) /* TODO */ && \
81 (defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))) /* GLIBC */
82 # define LIBXSMM_MALLOC_HOOK_DYNAMIC
83 #endif
84 #if !defined(LIBXSMM_MALLOC_HOOK_STATIC) && \
85 defined(LIBXSMM_MALLOC) && (0 != LIBXSMM_MALLOC) && \
86 (!defined(_WIN32)) /* TODO */ && \
87 (defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))) /* GLIBC */
88 # define LIBXSMM_MALLOC_HOOK_STATIC
89 #endif
90 #if !defined(LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS) && \
91 defined(LIBXSMM_MALLOC_HOOK_DYNAMIC) && \
92 defined(LIBXSMM_MALLOC_ALIGN_ALL)
93 # define LIBXSMM_DNN_CONVOLUTION_SETUP_USE_NTS
94 #endif
95
/* Dynamic interception needs <dlfcn.h> and a usable RTLD_NEXT handle. */
#if defined(LIBXSMM_INTERCEPT_DYNAMIC)
# if defined(LIBXSMM_OFFLOAD_TARGET)
#   pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
# endif
# include <dlfcn.h>
# if defined(LIBXSMM_OFFLOAD_TARGET)
#   pragma offload_attribute(pop)
# endif
  /* <dlfcn.h> may not expose RTLD_NEXT in every configuration; fall back to glibc's value. */
# if defined(RTLD_NEXT)
#   define LIBXSMM_RTLD_NEXT RTLD_NEXT
# else
#   define LIBXSMM_RTLD_NEXT ((void*)-1L)
# endif
#endif
110
/**
 * Verbosity thresholds: HIGH (default 3) is for secondary warnings or info output.
 * WARN is one level below HIGH, and it never drops below zero.
 */
#ifndef LIBXSMM_VERBOSITY_HIGH
# define LIBXSMM_VERBOSITY_HIGH 3 /* secondary warning or info-verbosity */
#endif
#ifndef LIBXSMM_VERBOSITY_WARN
# define LIBXSMM_VERBOSITY_WARN ((LIBXSMM_VERBOSITY_HIGH) - LIBXSMM_MIN(1, LIBXSMM_VERBOSITY_HIGH))
#endif

/* Lock kind used internally unless overridden. */
#ifndef LIBXSMM_LOCK
# define LIBXSMM_LOCK LIBXSMM_LOCK_DEFAULT
#endif
121
/**
 * Default (fallback) definitions:
 * - LIBXSMM_MIN_NTASKS(NT): minimum number of tasks, constant 1.
 * - LIBXSMM_OVERHEAD(NT): overhead estimate, constant 0.
 * - LIBXSMM_NOOP_ARGS(...) and LIBXSMM_NOOP: expand to nothing.
 * The LIBXSMM_MIN_NTASKS fallback stays disabled when LIBXSMM_EXT_MIN_NTASKS is defined.
 * It is also skipped when LIBXSMM_MIN_NTASKS itself is already defined, so that an external
 * definition is never redefined (the macro-name and guard-name used to differ).
 */
#if !defined(LIBXSMM_EXT_MIN_NTASKS) && !defined(LIBXSMM_MIN_NTASKS)
# define LIBXSMM_MIN_NTASKS(NT) 1
#endif
#if !defined(LIBXSMM_OVERHEAD)
# define LIBXSMM_OVERHEAD(NT) 0
#endif
#if !defined(LIBXSMM_NOOP_ARGS)
# define LIBXSMM_NOOP_ARGS(...)
#endif
#if !defined(LIBXSMM_NOOP)
# define LIBXSMM_NOOP
#endif
134
/**
 * Check if M, N, K, or LDx fits into the descriptor's unsigned int fields (at most 0xFFFFFFFF).
 * Without ILP64 the values always fit, so the check is the constant 1.
 */
#if (0 == LIBXSMM_ILP64) /* always fits */
# define LIBXSMM_GEMM_NO_BYPASS_DIMS(M, N, K) 1
#else
# define LIBXSMM_GEMM_NO_BYPASS_DIMS(M, N, K) ((M) <= 0xFFFFFFFF && (N) <= 0xFFFFFFFF && (K) <= 0xFFFFFFFF)
#endif

/* Assert the dimension check when LIBXSMM_ASSERT is available; otherwise compile it away. */
#if !defined(LIBXSMM_ASSERT)
# define LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(M, N, K)
#else /* assert available */
# define LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(M, N, K) LIBXSMM_ASSERT(LIBXSMM_GEMM_NO_BYPASS_DIMS(M, N, K))
#endif
147
/**
 * LIBXSMM_DESCRIPTOR_CLEAR_AUX(DST, SIZE): when LIBXSMM_UNPACKED is defined, this expands to
 * LIBXSMM_MEMSET127(DST, 0, SIZE). Otherwise it expands to nothing.
 */
#if defined(LIBXSMM_UNPACKED)
# define LIBXSMM_DESCRIPTOR_CLEAR_AUX(DST, SIZE) LIBXSMM_MEMSET127(DST, 0, SIZE)
#else
# define LIBXSMM_DESCRIPTOR_CLEAR_AUX(DST, SIZE)
#endif
/**
 * LIBXSMM_DESCRIPTOR_CLEAR(BLOB): asserts that *BLOB spans exactly LIBXSMM_DESCRIPTOR_MAXSIZE
 * bytes, then applies LIBXSMM_DESCRIPTOR_CLEAR_AUX to it.
 * The do/while(0) wrapper makes both steps a single statement. Without it, an unbraced
 * `if (c) LIBXSMM_DESCRIPTOR_CLEAR(p);` would guard only the assertion.
 */
#define LIBXSMM_DESCRIPTOR_CLEAR(BLOB) do { \
  LIBXSMM_ASSERT((LIBXSMM_DESCRIPTOR_MAXSIZE) == sizeof(*(BLOB))); \
  LIBXSMM_DESCRIPTOR_CLEAR_AUX(BLOB, LIBXSMM_DESCRIPTOR_MAXSIZE); \
} while (0)
156
/**
 * Low-level/internal GEMM descriptor initialization.
 * DESCRIPTOR is an lvalue of type libxsmm_gemm_descriptor, and every member of it is assigned:
 * - datatype and prefetch;
 * - flags: FLAGS, plus LIBXSMM_GEMM_FLAG_BETA_0 when LIBXSMM_NEQ(0, BETA) is false;
 * - extents m/n/k and leading dimensions lda/ldb/ldc;
 * - pad, c1..c3 and the meltw_* fields are zeroed.
 * ALPHA is accepted but currently not encoded (see the commented-out flag term).
 * The expansion is wrapped in do/while(0) so that it behaves as one statement
 * (e.g., as the body of an unbraced if). Arguments may be evaluated more than once.
 */
#define LIBXSMM_GEMM_DESCRIPTOR(DESCRIPTOR, DATA_TYPE, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) do { \
  LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(LDA, LDB, LDC); \
  LIBXSMM_GEMM_DESCRIPTOR_DIM_CHECK(M, N, K); \
  LIBXSMM_DESCRIPTOR_CLEAR_AUX(&(DESCRIPTOR), sizeof(DESCRIPTOR)); \
  (DESCRIPTOR).datatype = (unsigned char)(DATA_TYPE); (DESCRIPTOR).prefetch = (unsigned char)(PREFETCH); \
  (DESCRIPTOR).flags = (unsigned int)((FLAGS) \
    /*| (LIBXSMM_NEQ(0, ALPHA) ? 0 : LIBXSMM_GEMM_FLAG_ALPHA_0)*/ \
    | (LIBXSMM_NEQ(0, BETA) ? 0 : LIBXSMM_GEMM_FLAG_BETA_0)); \
  (DESCRIPTOR).m = (unsigned int)(M); (DESCRIPTOR).n = (unsigned int)(N); (DESCRIPTOR).k = (unsigned int)(K); \
  (DESCRIPTOR).lda = (unsigned int)(LDA); (DESCRIPTOR).ldb = (unsigned int)(LDB); (DESCRIPTOR).ldc = (unsigned int)(LDC); \
  LIBXSMM_PAD((DESCRIPTOR).pad = 0) (DESCRIPTOR).c1 = 0; (DESCRIPTOR).c2 = 0; (DESCRIPTOR).c3 = 0; \
  (DESCRIPTOR).meltw_ldx = 0; (DESCRIPTOR).meltw_ldy = 0; (DESCRIPTOR).meltw_ldz = 0; \
  (DESCRIPTOR).meltw_datatype_aux = 0; (DESCRIPTOR).meltw_flags = 0; \
  (DESCRIPTOR).meltw_operation = 0; \
} while (0)
172
/**
 * Like LIBXSMM_GEMM_DESCRIPTOR, but the data type is given as separate input (IPREC) and
 * output (OPREC) precision, combined via LIBXSMM_GETENUM.
 */
#define LIBXSMM_GEMM_DESCRIPTOR2(DESCRIPTOR, IPREC, OPREC, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) \
  LIBXSMM_GEMM_DESCRIPTOR(DESCRIPTOR, LIBXSMM_GETENUM(IPREC, OPREC), \
    FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH)

/** Declares a libxsmm_gemm_descriptor named DESCRIPTOR and initializes it (LIBXSMM_GEMM_DESCRIPTOR). */
#define LIBXSMM_GEMM_DESCRIPTOR_TYPE(DESCRIPTOR, DATA_TYPE, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) \
  libxsmm_gemm_descriptor DESCRIPTOR; \
  LIBXSMM_GEMM_DESCRIPTOR(DESCRIPTOR, DATA_TYPE, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH)

/** Declares a libxsmm_gemm_descriptor named DESCRIPTOR and initializes it from separate IPREC/OPREC. */
#define LIBXSMM_GEMM_DESCRIPTOR2_TYPE(DESCRIPTOR, IPREC, OPREC, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH) \
  libxsmm_gemm_descriptor DESCRIPTOR; \
  LIBXSMM_GEMM_DESCRIPTOR2(DESCRIPTOR, IPREC, OPREC, FLAGS, M, N, K, LDA, LDB, LDC, ALPHA, BETA, PREFETCH)
185
/**
 * LIBXSMM_REGDESC(START, MODIFIER) expands to one member declaration per registrable descriptor
 * type, in the form `START <type> MODIFIER <name>`. Entries are separated by semicolons, and the
 * caller terminates the last one. LIBXSMM_REGDESC_DEFAULT is an empty START.
 */
#define LIBXSMM_REGDESC_DEFAULT
#define LIBXSMM_REGDESC(START, MODIFIER) \
  START libxsmm_gemm_descriptor  MODIFIER gemm;  \
  START libxsmm_mcopy_descriptor MODIFIER mcopy; \
  START libxsmm_meltw_descriptor MODIFIER meltw; \
  START libxsmm_trans_descriptor MODIFIER trans; \
  START libxsmm_pgemm_descriptor MODIFIER pgemm; \
  START libxsmm_getrf_descriptor MODIFIER getrf; \
  START libxsmm_trmm_descriptor  MODIFIER trmm;  \
  START libxsmm_trsm_descriptor  MODIFIER trsm
196
197
198 /**
199 * Packed structure, which stores the argument description of GEMM routines.
200 * The size of the structure is padded to LIBXSMM_DESCRIPTOR_MAXSIZE.
201 */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)202 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_gemm_descriptor {
203 /** Extents of the matrix. */
204 unsigned int m, n, k;
205 /** Leading dimensions. */
206 unsigned int lda, ldb, ldc;
207 /** Set of flags. */
208 unsigned int flags;
209 /** Prefetch strategy. */
210 unsigned char prefetch;
211 /** Denotes the data-type. */
212 unsigned char datatype;
213 /** Ignored entry. */
214 LIBXSMM_PAD(unsigned char pad)
215 /** multipurpose 64bit field, currently used for: a) stride_a in brgemm */
216 unsigned long long c1;
217 /** multipurpose 64bit field, currently used for: a) stride_b in brgemm */
218 unsigned long long c2;
219 /** multipurpose 8bit field, currently used for: a) unroll hint in brgemm */
220 unsigned char c3;
221 /** LDx, LDy, LDz, additional meltw LDs */
222 unsigned int meltw_ldx, meltw_ldy, meltw_ldz;
223 /** Size of data element. */
224 unsigned char meltw_datatype_aux;
225 /** Set of flags */
226 unsigned char meltw_flags;
227 /** operation specifier */
228 unsigned char meltw_operation;
229 };
230
231 /** Packed structure storing the matcopy argument description. */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)232 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_mcopy_descriptor {
233 /** LDx, M, and N. */
234 unsigned int m, n, ldi, ldo;
235 /** Size of data element. */
236 unsigned char typesize;
237 /** Level of unrolling. */
238 unsigned char unroll_level;
239 /** Boolean value (@TODO fix this). */
240 unsigned char prefetch;
241 /** Set of flags. */
242 unsigned char flags;
243 };
244
245 /** Packed structure storing the mateltw argument description. */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)246 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_meltw_descriptor {
247 /** LDx, M, and N. */
248 unsigned int m, n, ldi, ldo, ldx, ldy;
249 /** Size of data element. */
250 unsigned char datatype;
251 unsigned char datatype2;
252 /** Set of flags */
253 unsigned char flags;
254 /** operation specifier */
255 unsigned char operation;
256 };
257
258 /** Packed structure storing the transpose argument description. */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)259 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_trans_descriptor {
260 /** LD, M, and N. */
261 unsigned int m, n, ldo;
262 /** Size of data element. */
263 unsigned char typesize;
264 };
265
266 /** Packed structure storing arguments of packed GEMM. */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)267 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_pgemm_descriptor {
268 unsigned int m, n, k, lda, ldb, ldc;
269 unsigned char typesize;
270 unsigned char layout;
271 char transa, transb;
272 char alpha_val;
273 };
274
275 /** Packed structure storing arguments of packed GETRF. */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)276 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_getrf_descriptor {
277 unsigned int m, n, lda;
278 unsigned char typesize;
279 unsigned char layout;
280 };
281
282 /** Packed structure storing arguments of packed TRSM. */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)283 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_trmm_descriptor {
284 union { double d; float s; } alpha;
285 unsigned int m, n, lda, ldb;
286 unsigned char typesize;
287 unsigned char layout;
288 char diag, side, uplo;
289 char transa;
290 };
291
292 /** Packed structure storing arguments of packed TRSM. */
LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE)293 LIBXSMM_EXTERN_C LIBXSMM_PACKED(struct LIBXSMM_RETARGETABLE) libxsmm_trsm_descriptor {
294 union { double d; float s; } alpha;
295 unsigned int m, n, lda, ldb;
296 unsigned char typesize;
297 unsigned char layout;
298 char diag, side, uplo;
299 char transa;
300 };
301
302 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_csr_soa_descriptor {
303 const libxsmm_gemm_descriptor* gemm;
304 const unsigned int* row_ptr;
305 const unsigned int* column_idx;
306 const void* values;
307 unsigned int packed_width;
308 } libxsmm_csr_soa_descriptor;
309
310 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_csc_soa_descriptor {
311 const libxsmm_gemm_descriptor* gemm;
312 const unsigned int* column_ptr;
313 const unsigned int* row_idx;
314 const void* values;
315 unsigned int packed_width;
316 } libxsmm_csc_soa_descriptor;
317
318 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_pgemm_ac_rm_descriptor {
319 const libxsmm_gemm_descriptor* gemm;
320 unsigned int packed_width;
321 } libxsmm_pgemm_ac_rm_descriptor;
322
323 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_pgemm_bc_rm_descriptor {
324 const libxsmm_gemm_descriptor* gemm;
325 unsigned int packed_width;
326 } libxsmm_pgemm_bc_rm_descriptor;
327
328 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE LIBXSMM_MAY_ALIAS libxsmm_csr_reg_descriptor {
329 const libxsmm_gemm_descriptor* gemm;
330 const unsigned int* row_ptr;
331 const unsigned int* column_idx;
332 const void* values;
333 } libxsmm_csr_reg_descriptor;
334
335 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_code_pointer {
336 void (*ptr_fn)(LIBXSMM_VARIADIC);
337 const void* ptr_const;
338 void* ptr;
339 uintptr_t uval;
340 intptr_t ival;
341 libxsmm_xmmfunction xgemm; /* GEMM: smm, dmm, wimm, or void-function */
342 libxsmm_xmcopyfunction xmatcopy;
343 libxsmm_xmeltwfunction xmateltw;
344 libxsmm_xtransfunction xtrans;
345 libxsmm_pgemm_xfunction xpgemm;
346 libxsmm_getrf_xfunction xgetrf;
347 libxsmm_trmm_xfunction xtrmm;
348 libxsmm_trsm_xfunction xtrsm;
349 } libxsmm_code_pointer;
350
351 /** Structure which describes all tensors in LIBXSMM's DNN module */
352 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor {
353 libxsmm_dnn_tensor_datalayout* layout; /* data-layout descriptor */
354 void* data; /* pointer to data */
355 unsigned char scf; /* fix point scaling factor for this tensor */
356 };
357
358 /* Structure to record segment in stream of code */
359 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE segment_t {
360 int segment_type;
361 int n_convs;
362 int aux_index;
363 } segment_t;
364
365 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_layer {
366 libxsmm_dnn_datatype datatype_in;
367 libxsmm_dnn_datatype datatype_out;
368 libxsmm_dnn_conv_desc desc;
369 libxsmm_dnn_conv_algo algo;
370 libxsmm_dnn_tensor_format buffer_format;
371 libxsmm_dnn_tensor_format filter_format;
372 libxsmm_dnn_conv_fuse_op fuse_ops;
373 libxsmm_dnn_conv_option options;
374
375 /* additional size for internal data types */
376 int ifhp;
377 int ifwp;
378 int ofh;
379 int ofw;
380 int ofhp;
381 int ofwp;
382 int ifmblock;
383 int ofmblock;
384 int blocksifm;
385 int blocksofm;
386 int fwd_ofw_rb;
387 int fwd_ofh_rb;
388 int bwd_ofw_rb;
389 int bwd_ofh_rb;
390 int upd_ofw_rb;
391 int upd_ofh_rb;
392 int fm_lp_block; /* additional blocking for low precision datatypes of feature maps */
393 int blocksifm_blocking;
394 int blocksofm_blocking;
395 int avoid_acc_load;
396 int avoid_acc_load_bwd;
397 int pack_input;
398 int pack_input_bwd;
399 int spread_input_bwd;
400 int weight_copies;
401 int loop_order;
402 int use_ofm_parallelization;
403 int use_ifm_parallelization;
404 int avoid_fmas_in_rim;
405 int upd_use_batchreduce;
406 int upd_pack_input;
407 int upd_loop_order;
408 int upd_linearized_tasklist;
409 int upd_avoid_rim_fmas;
410 int fwd_flags;
411 int shuffle_filter_accesses;
412 int use_fallback_fwd_loops;
413 int use_fallback_bwd_loops;
414 int input_pixels;
415 int output_pixels;
416 int n_used_pixels;
417 int pixel_blocking;
418 int use_intermediate_f32_wt_tensor;
419 int upd_linearized_pixels;
420 int ifwp_extended;
421 int ofwp_extended;
422 int batchreduce_h_pixels;
423 int on_the_fly_input_packing;
424 int upd_pack_input_upfront;
425 int use_hybrid_imgofm_parallelization;
426 int compute_pixels;
427 int upd_trans_w_only;
428 int fwd_padding_copy;
429 int upd_padding_copy;
430 int block_fwd_oj;
431 int block_fwd_ifm;
432 int block_fwd_ofm;
433 int block_bwd_oj;
434 int block_bwd_ifm;
435 int block_bwd_ofm;
436 int block_upd_ifm;
437 int block_upd_ofm;
438
439 libxsmm_xtransfunction tr_kernel;
440 libxsmm_meltwfunction_cvtfp32bf16 fwd_cvtfp32bf16_kernel;
441
442 /* internal data representation */
443 libxsmm_dnn_tensor* reg_input;
444 libxsmm_dnn_tensor* reg_output;
445 libxsmm_dnn_tensor* reg_filter;
446 libxsmm_dnn_tensor* grad_input;
447 libxsmm_dnn_tensor* grad_output;
448 libxsmm_dnn_tensor* grad_filter;
449 libxsmm_dnn_tensor* reg_bias;
450 libxsmm_dnn_tensor* grad_bias;
451 /* internal data representations for copies of tensors */
452 libxsmm_dnn_tensor* reg_input_tr;
453 libxsmm_dnn_tensor* reg_filter_tr;
454 /* batchnorm stats */
455 libxsmm_dnn_tensor* batch_stats;
456 /* maxstats used in low-precision kernels */
457 libxsmm_dnn_tensor* maxstats_fwd;
458 libxsmm_dnn_tensor* maxstats_bwd;
459 libxsmm_dnn_tensor* maxstats_upd;
460
461 /* barrier */
462 libxsmm_barrier* barrier;
463
464 /* scratch */
465 size_t fwd_packing_padding_scratch_size;
466 size_t fwd_lp_output_full_scratch_size;
467 size_t fwd_lp_output_block_scratch_size;
468 size_t fwd_packing_padding_scratch_offset;
469 size_t fwd_lp_output_full_scratch_offset;
470 size_t fwd_lp_output_block_scratch_offset;
471 size_t fwd_scratch_size;
472
473 size_t bwd_filter_trans_scratch_size;
474 size_t bwd_packing_padding_scratch_size;
475 size_t bwd_lp_input_full_scratch_size;
476 size_t bwd_filter_trans_scratch_offset;
477 size_t bwd_packing_padding_scratch_offset;
478 size_t bwd_lp_input_full_scratch_offset;
479 size_t bwd_scratch_size;
480
481 size_t upd_packing_padding_scratch_size;
482 size_t upd_lp_output_full_scratch_size;
483 size_t upd_lp_input_full_scratch_size;
484 size_t upd_filter_scratch_size;
485 size_t upd_lp_filter_full_scratch_size;
486 size_t upd_packing_padding_scratch_offset;
487 size_t upd_lp_output_full_scratch_offset;
488 size_t upd_lp_input_full_scratch_offset;
489 size_t upd_lp_filter_full_scratch_offset;
490 size_t upd_filter_scratch_offset;
491 size_t upd_scratch_size;
492
493 void* scratch;
494 size_t scratch_size;
495
496 libxsmm_code_pointer gemm_fwd; /* ability to hoist forward GEMMs */
497 libxsmm_code_pointer gemm_fwd2; /* ability to hoist forward GEMMs */
498
499 unsigned long long *A_offsets;
500 unsigned long long *B_offsets;
501
502 /* JIT-generated convolution code */
503 libxsmm_code_pointer code_fwd[3];
504 libxsmm_code_pointer code_bwd[3];
505 libxsmm_code_pointer code_upd[2];
506
507 libxsmm_code_pointer matcopy_fwd[4];
508 libxsmm_code_pointer matcopy_bwd[4];
509 libxsmm_code_pointer matcopy_upd[3];
510 };
511
512 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedbatchnorm {
513 libxsmm_dnn_fusedbatchnorm_desc desc;
514 libxsmm_dnn_tensor* reg_input; /* input tensor */
515 libxsmm_dnn_tensor* reg_output; /* output tensor */
516 libxsmm_dnn_tensor* grad_input; /* grad input tensor */
517 libxsmm_dnn_tensor* grad_output; /* grad output tensor */
518 libxsmm_dnn_tensor* reg_add; /* elementwise tensor */
519 libxsmm_dnn_tensor* grad_add; /* grad elementwise tensor */
520 libxsmm_dnn_tensor* reg_beta; /* beta tensor */
521 libxsmm_dnn_tensor* reg_gamma; /* gamma tensor */
522 libxsmm_dnn_tensor* grad_beta; /* grad beta tensor */
523 libxsmm_dnn_tensor* grad_gamma; /* grad gamma tensor */
524 libxsmm_dnn_tensor* expvalue; /* expected value */
525 libxsmm_dnn_tensor* rcpstddev; /* reciprocal of standard derivation */
526 libxsmm_dnn_tensor* variance; /* variance */
527 libxsmm_dnn_tensor* relumask; /* relumask */
528 libxsmm_barrier* barrier; /* barrier */
529 int ifmblock;
530 int ofmblock;
531 int blocksifm;
532 int blocksofm;
533 size_t scratch_size;
534 void* scratch;
535 };
536
537 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_softmaxloss {
538 libxsmm_dnn_softmaxloss_desc desc;
539 libxsmm_dnn_tensor* reg_input; /* input tensor */
540 libxsmm_dnn_tensor* reg_output; /* output tensor */
541 libxsmm_dnn_tensor* grad_input; /* grad input tensor */
542 libxsmm_dnn_tensor* label; /* labels tensor */
543 libxsmm_barrier* barrier; /* barrier */
544 int bc;
545 int Bc;
546 int bn;
547 int Bn;
548 float loss;
549 size_t scratch_size;
550 void* scratch;
551 };
552
553 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_optimizer {
554 libxsmm_dnn_optimizer_desc desc;
555 libxsmm_dnn_tensor* reg_filter; /* filter tensor */
556 libxsmm_dnn_tensor* grad_filter; /* grad filter tensor */
557 libxsmm_dnn_tensor* master_filter; /* master filter tensor */
558 libxsmm_barrier* barrier; /* barrier */
559 int bc;
560 int Bc;
561 int bk;
562 int Bk;
563 int fm_lp_block;
564 size_t scratch_size;
565 void* scratch;
566 };
567
568 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_fusedgroupnorm {
569 libxsmm_dnn_fusedgroupnorm_desc desc;
570 libxsmm_dnn_tensor* reg_input; /* input tensor */
571 libxsmm_dnn_tensor* reg_output; /* output tensor */
572 libxsmm_dnn_tensor* grad_input; /* grad input tensor */
573 libxsmm_dnn_tensor* grad_output; /* grad output tensor */
574 libxsmm_dnn_tensor* reg_add; /* elementwise tensor */
575 libxsmm_dnn_tensor* grad_add; /* grad elementwise tensor */
576 libxsmm_dnn_tensor* reg_beta; /* beta tensor */
577 libxsmm_dnn_tensor* reg_gamma; /* gamma tensor */
578 libxsmm_dnn_tensor* grad_beta; /* grad beta tensor */
579 libxsmm_dnn_tensor* grad_gamma; /* grad gamma tensor */
580 libxsmm_dnn_tensor* expvalue; /* expected value */
581 libxsmm_dnn_tensor* rcpstddev; /* reciprocal of standard derivation */
582 libxsmm_dnn_tensor* variance; /* variance */
583 libxsmm_dnn_tensor* relumask; /* relumask */
584 libxsmm_barrier* barrier; /* barrier */
585 int ifmblock;
586 int ofmblock;
587 int blocksifm;
588 int blocksofm;
589 size_t scratch_size;
590 void* scratch;
591 };
592
593 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_fullyconnected {
594 libxsmm_dnn_fullyconnected_desc desc;
595 libxsmm_dnn_tensor* reg_input; /* input tensor */
596 libxsmm_dnn_tensor* reg_output; /* output tensor */
597 libxsmm_dnn_tensor* grad_input; /* grad input tensor */
598 libxsmm_dnn_tensor* grad_output; /* grad output tensor */
599 libxsmm_dnn_tensor* reg_filter; /* filter tensor */
600 libxsmm_dnn_tensor* grad_filter; /* grad filter tensor */
601 libxsmm_dnn_tensor* reg_bias; /* bias tensor */
602 libxsmm_dnn_tensor* grad_bias; /* grad bais tensor */
603 libxsmm_dnn_tensor* relumask; /* relumask */
604 libxsmm_barrier* barrier; /* barrier */
605 int ifmblock;
606 int ofmblock;
607 int blocksifm;
608 int blocksofm;
609 /* Parameters to tune/specialize FC algorithms */
610 int fwd_2d_blocking;
611 int bwd_2d_blocking;
612 int upd_2d_blocking;
613 int fwd_bf;
614 int bwd_bf;
615 int upd_bf;
616 int fwd_row_teams;
617 int fwd_column_teams;
618 int bwd_row_teams;
619 int bwd_column_teams;
620 int upd_row_teams;
621 int upd_column_teams;
622 int ifm_subtasks;
623 int ofm_subtasks;
624
625 int fm_lp_block;
626 int bn;
627 int bk;
628 int bc;
629 size_t scratch_size;
630 size_t doutput_scratch_mark;
631 void* scratch;
632
633 libxsmm_xtransfunction tr_kernel;
634 libxsmm_code_pointer gemm_fwd; /* ability to hoist forward GEMMs */
635 libxsmm_code_pointer gemm_fwd2; /* ability to hoist forward GEMMs */
636 libxsmm_code_pointer gemm_fwd3; /* ability to hoist forward GEMMs */
637 libxsmm_code_pointer gemm_bwd; /* ability to hoist backward GEMMs */
638 libxsmm_code_pointer gemm_bwd2; /* ability to hoist backward GEMMs */
639 libxsmm_code_pointer gemm_upd; /* ability to hoist update GEMMs */
640 libxsmm_code_pointer gemm_upd2; /* ability to hoist update GEMMs */
641 };
642
643 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_pooling {
644 libxsmm_dnn_pooling_desc desc;
645 libxsmm_dnn_tensor* reg_input; /* input tensor */
646 libxsmm_dnn_tensor* reg_output; /* output tensor */
647 libxsmm_dnn_tensor* grad_input; /* grad input tensor */
648 libxsmm_dnn_tensor* grad_output; /* grad output tensor */
649 libxsmm_dnn_tensor* mask; /* elementwise tensor */
650 libxsmm_barrier* barrier; /* barrier */
651 int ifmblock;
652 int ofmblock;
653 int blocksifm;
654 int blocksofm;
655 int ofh;
656 int ofw;
657 size_t scratch_size;
658 void* scratch;
659 };
660
661 LIBXSMM_EXTERN_C struct LIBXSMM_RETARGETABLE libxsmm_dnn_rnncell {
662 libxsmm_dnn_rnncell_desc desc;
663 libxsmm_blasint T; /* sequence length, must be smaller than max sequence length in desc */
664 libxsmm_blasint bk;
665 libxsmm_blasint bn;
666 libxsmm_blasint bc;
667 libxsmm_blasint lpb;
668
669 /* external tensors */
670 libxsmm_dnn_tensor* xt;
671 libxsmm_dnn_tensor* csp;
672 libxsmm_dnn_tensor* hp;
673 libxsmm_dnn_tensor* w;
674 libxsmm_dnn_tensor* wt;
675 libxsmm_dnn_tensor* r;
676 libxsmm_dnn_tensor* rt;
677 libxsmm_dnn_tensor* b;
678 libxsmm_dnn_tensor* cst;
679 libxsmm_dnn_tensor* ht;
680 libxsmm_dnn_tensor* dxt;
681 libxsmm_dnn_tensor* dcsp;
682 libxsmm_dnn_tensor* dhp;
683 libxsmm_dnn_tensor* dw;
684 libxsmm_dnn_tensor* dr;
685 libxsmm_dnn_tensor* db;
686 libxsmm_dnn_tensor* dcs;
687 libxsmm_dnn_tensor* dht;
688 libxsmm_dnn_tensor* it;
689 libxsmm_dnn_tensor* ft;
690 libxsmm_dnn_tensor* ot;
691 libxsmm_dnn_tensor* cit;
692 libxsmm_dnn_tensor* cot;
693 float forget_bias;
694 /* internal state */
695 void* internal_z;
696 /* scratch pointers */
697 void* scratch_base;
698 void* scratch_wT;
699 void* scratch_rT;
700 void* scratch_w;
701 void* scratch_r;
702 void* scratch_xT;
703 void* scratch_hT;
704 void* scratch_deltat;
705 void* scratch_di;
706 void* scratch_df;
707 void* scratch_do;
708 void* scratch_dci;
709 void* scratch_diB;
710 void* scratch_dfB;
711 void* scratch_dpB;
712 void* scratch_dciB;
713 void* scratch_dx;
714 void* scratch_dhp;
715 void* scratch_db;
716 void* scratch_t1;
717 void* scratch_t2;
718 void* csp_scratch;
719 void* cst_scratch;
720 void* ht_scratch;
721 void* it_scratch;
722 void* ft_scratch;
723 void* ot_scratch;
724 void* cit_scratch;
725 void* cot_scratch;
726
727 /* Ability to hoist GEMMs */
728 libxsmm_bsmmfunction_reducebatch_strd fwd_kernela;
729 libxsmm_bsmmfunction_reducebatch_strd fwd_kernelb;
730 libxsmm_bsmmfunction_reducebatch_strd bwdupd_kernela;
731 libxsmm_bsmmfunction_reducebatch_strd bwdupd_kernelb;
732 libxsmm_bsmmfunction_reducebatch_strd bwdupd_kernelc;
733 libxsmm_bsmmfunction_reducebatch_strd bwdupd_kerneld;
734
735 libxsmm_barrier* barrier; /* barrier */
736 };
737
738 struct LIBXSMM_RETARGETABLE libxsmm_dfsspmdm {
739 int M;
740 int N;
741 int K;
742 int ldb;
743 int ldc;
744 int N_chunksize;
745 unsigned int* permute_operands;
746 double* a_dense;
747 libxsmm_dmmfunction kernel;
748 };
749
750 struct LIBXSMM_RETARGETABLE libxsmm_sfsspmdm {
751 int M;
752 int N;
753 int K;
754 int ldb;
755 int ldc;
756 int N_chunksize;
757 unsigned int* permute_operands;
758 float* a_dense;
759 libxsmm_smmfunction kernel;
760 };
761
762 typedef enum libxsmm_build_kind {
763 LIBXSMM_BUILD_KIND_GEMM = LIBXSMM_KERNEL_KIND_MATMUL,
764 LIBXSMM_BUILD_KIND_MCOPY = LIBXSMM_KERNEL_KIND_MCOPY,
765 LIBXSMM_BUILD_KIND_MELTW = LIBXSMM_KERNEL_KIND_MELTW,
766 LIBXSMM_BUILD_KIND_TRANS = LIBXSMM_KERNEL_KIND_TRANS,
767 LIBXSMM_BUILD_KIND_PGEMM = LIBXSMM_KERNEL_KIND_PGEMM,
768 LIBXSMM_BUILD_KIND_GETRF = LIBXSMM_KERNEL_KIND_GETRF,
769 LIBXSMM_BUILD_KIND_TRMM = LIBXSMM_KERNEL_KIND_TRMM,
770 LIBXSMM_BUILD_KIND_TRSM = LIBXSMM_KERNEL_KIND_TRSM,
771 LIBXSMM_BUILD_KIND_USER = LIBXSMM_KERNEL_KIND_USER,
772 LIBXSMM_BUILD_KIND_PGEMMRMAC = LIBXSMM_KERNEL_UNREGISTERED,
773 LIBXSMM_BUILD_KIND_PGEMMRMBC,
774 LIBXSMM_BUILD_KIND_SRSOA,
775 LIBXSMM_BUILD_KIND_SCSOA,
776 LIBXSMM_BUILD_KIND_SREG
777 } libxsmm_build_kind;
778
/** Integral type holding a libxsmm_kernel_kind or libxsmm_build_kind: byte-sized unless LIBXSMM_UNPACKED. */
#if !defined(LIBXSMM_UNPACKED)
typedef unsigned char libxsmm_descriptor_kind;
#else
typedef size_t libxsmm_descriptor_kind;
#endif
785
786 /** All descriptor types, which are valid for code-registration. */
787 LIBXSMM_EXTERN_C typedef union LIBXSMM_RETARGETABLE libxsmm_descriptor {
788 char data[LIBXSMM_DESCRIPTOR_MAXSIZE];
789 libxsmm_descriptor_kind kind; /* kind: must be the first member */
790 LIBXSMM_REGDESC(LIBXSMM_PACKED(struct) { libxsmm_descriptor_kind /*repeated kind*/ pad; , desc; });
LIBXSMM_PACKED(struct)791 LIBXSMM_PACKED(struct) { libxsmm_descriptor_kind /*repeated kind*/ pad; char desc[1]; } user;
792 } libxsmm_descriptor;
793
794 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_build_request {
795 union {
796 const void* ptr; /* raw content */
797 LIBXSMM_REGDESC(LIBXSMM_REGDESC_DEFAULT, const*);
798 const libxsmm_csr_soa_descriptor* srsoa;
799 const libxsmm_csc_soa_descriptor* scsoa;
800 const libxsmm_pgemm_ac_rm_descriptor* pgemmacrm;
801 const libxsmm_pgemm_bc_rm_descriptor* pgemmbcrm;
802 const libxsmm_csr_reg_descriptor* sreg;
803 } descriptor;
804 libxsmm_build_kind kind;
805 /* used by user-kind */
806 size_t user_size;
807 } libxsmm_build_request;
808
/** Bit flags describing an allocation; VALID is the union of all defined bits. */
typedef enum libxsmm_malloc_flags {
  LIBXSMM_MALLOC_FLAG_DEFAULT = 0,
  LIBXSMM_MALLOC_FLAG_SCRATCH = 1 << 0,
  LIBXSMM_MALLOC_FLAG_PRIVATE = 1 << 1,
  LIBXSMM_MALLOC_FLAG_REALLOC = 1 << 2,
  LIBXSMM_MALLOC_FLAG_PHUGE   = 1 << 3,
  LIBXSMM_MALLOC_FLAG_PLOCK   = 1 << 4,
  LIBXSMM_MALLOC_FLAG_MMAP    = 1 << 5,
  /* access permissions */
  LIBXSMM_MALLOC_FLAG_R       = 1 << 6,
  LIBXSMM_MALLOC_FLAG_W       = 1 << 7,
  LIBXSMM_MALLOC_FLAG_X       = 1 << 8,
  LIBXSMM_MALLOC_FLAG_RW      = LIBXSMM_MALLOC_FLAG_W | LIBXSMM_MALLOC_FLAG_R,
  LIBXSMM_MALLOC_FLAG_WX      = LIBXSMM_MALLOC_FLAG_W | LIBXSMM_MALLOC_FLAG_X,
  LIBXSMM_MALLOC_FLAG_RWX     = LIBXSMM_MALLOC_FLAG_RW | LIBXSMM_MALLOC_FLAG_X,
  LIBXSMM_MALLOC_FLAG_VALID   = LIBXSMM_MALLOC_FLAG_RWX
    | LIBXSMM_MALLOC_FLAG_MMAP | LIBXSMM_MALLOC_FLAG_PLOCK | LIBXSMM_MALLOC_FLAG_PHUGE
    | LIBXSMM_MALLOC_FLAG_REALLOC | LIBXSMM_MALLOC_FLAG_PRIVATE | LIBXSMM_MALLOC_FLAG_SCRATCH
} libxsmm_malloc_flags;
828
829 LIBXSMM_EXTERN_C typedef LIBXSMM_RETARGETABLE void* (*libxsmm_realloc_fun)(void* /*ptr*/, size_t /*size*/);
830
831 #if defined(LIBXSMM_MALLOC_HOOK_DYNAMIC)
832 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_malloc_fntype {
833 union { const void* dlsym; void* (*ptr)(size_t, size_t); } alignmem;
834 union { const void* dlsym; void* (*ptr)(size_t, size_t); } memalign;
835 union { const void* dlsym; libxsmm_malloc_fun ptr; } malloc;
836 # if defined(LIBXSMM_MALLOC_HOOK_CALLOC)
837 union { const void* dlsym; void* (*ptr)(size_t, size_t); } calloc;
838 # endif
839 # if defined(LIBXSMM_MALLOC_HOOK_REALLOC)
840 union { const void* dlsym; libxsmm_realloc_fun ptr; } realloc;
841 # endif
842 union { const void* dlsym; libxsmm_free_fun ptr; } free;
843 } libxsmm_malloc_fntype;
844 LIBXSMM_APIVAR_PRIVATE(libxsmm_malloc_fntype libxsmm_malloc_fn);
845 #endif
846
#if defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD))
/* Prototypes for GLIBC's internal allocator implementation (calloc/realloc only if hooked). */
LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_memalign(size_t alignment, size_t size);
LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_malloc(size_t size);
# if defined(LIBXSMM_MALLOC_HOOK_CALLOC)
LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_calloc(size_t num, size_t size);
# endif
# if defined(LIBXSMM_MALLOC_HOOK_REALLOC)
LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void* __libc_realloc(void* ptr, size_t size);
# endif
LIBXSMM_EXTERN_C LIBXSMM_RETARGETABLE void __libc_free(void* ptr);
#endif /*(defined(LIBXSMM_BUILD) && (1 < (LIBXSMM_BUILD)))*/
859
/**
 * Internal aligned-allocation helper; takes (alignment, size) in the same
 * order as memalign and returns the allocated pointer.
 */
LIBXSMM_API_INTERN void* libxsmm_memalign_internal(size_t alignment, size_t size);
861
/**
 * Linker-provided "real" allocation functions. When linking with
 * -Wl,--wrap=<symbol>, references to __real_<symbol> resolve to the original
 * definition of <symbol>. See
 * https://sourceware.org/binutils/docs-2.34/ld/Options.html#index-_002d_002dwrap_003dsymbol
 * They are declared weak, presumably so that linking still succeeds when
 * --wrap is not used (an unresolved weak reference has a NULL address on ELF
 * targets). NOTE(review): callers presumably check for NULL before calling --
 * confirm in the implementation.
 * The calloc/realloc variants are declared only if the corresponding
 * LIBXSMM_MALLOC_HOOK_* switch is defined.
 */
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_memalign(size_t alignment, size_t size);
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_malloc(size_t size);
#if defined(LIBXSMM_MALLOC_HOOK_CALLOC)
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_calloc(size_t num, size_t size);
#endif
#if defined(LIBXSMM_MALLOC_HOOK_REALLOC)
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void* __real_realloc(void* ptr, size_t size);
#endif
LIBXSMM_API_INTERN LIBXSMM_ATTRIBUTE_WEAK void __real_free(void* ptr);
872
/**
 * Retrieve internal information about a buffer (default memory domain).
 * memory: the buffer to query (presumably allocated by LIBXSMM -- confirm).
 * size, flags, extra: outputs receiving the buffer's size, its allocation flags
 *   (presumably libxsmm_malloc_flags), and the extra information stored along
 *   with the allocation (presumably the data given to libxsmm_xmalloc).
 * Returns an int status (presumably EXIT_SUCCESS on success -- confirm).
 * Note: unlike most declarations here, this one is part of LIBXSMM_API.
 */
LIBXSMM_API int libxsmm_get_malloc_xinfo(const void* memory, size_t* size, int* flags, void** extra);
875
/** Initializes malloc hooks and other allocator internals. */
LIBXSMM_API_INTERN void libxsmm_malloc_init(void);
/** Counterpart of libxsmm_malloc_init (presumably tears down what init set up). */
LIBXSMM_API_INTERN void libxsmm_malloc_finalize(void);
879
/**
 * Calculates an alignment depending on the (supposedly) allocated size.
 * An alignment of zero requests automatic selection ("auto").
 * Returns the alignment to be used (in Bytes, presumably).
 */
LIBXSMM_API_INTERN size_t libxsmm_alignment(size_t size, size_t alignment);
882
/**
 * Same as libxsmm_set_default_allocator, but takes a lock (can be NULL).
 * context, malloc_fn, free_fn: the allocator to install.
 */
LIBXSMM_API_INTERN int libxsmm_xset_default_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock,
  const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn);
/**
 * Same as libxsmm_get_default_allocator, but takes a lock (can be NULL).
 * context, malloc_fn, free_fn: outputs receiving the current allocator.
 */
LIBXSMM_API_INTERN int libxsmm_xget_default_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock,
  const void** context, libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn);

/**
 * Same as libxsmm_set_scratch_allocator, but takes a lock (can be NULL).
 * context, malloc_fn, free_fn: the allocator to install.
 */
LIBXSMM_API_INTERN int libxsmm_xset_scratch_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock,
  const void* context, libxsmm_malloc_function malloc_fn, libxsmm_free_function free_fn);
/**
 * Same as libxsmm_get_scratch_allocator, but takes a lock (can be NULL).
 * context, malloc_fn, free_fn: outputs receiving the current allocator.
 */
LIBXSMM_API_INTERN int libxsmm_xget_scratch_allocator(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock,
  const void** context, libxsmm_malloc_function* malloc_fn, libxsmm_free_function* free_fn);
896
897 /**
898 * Attribute memory allocation and protect with only the necessary flags.
899 * This procedure is expected to run only one time per buffer, and may
900 * relocate the given memory.
901 */
902 LIBXSMM_API_INTERN int libxsmm_malloc_attrib(void** memory, int flags,
903 /** If a name is given, an executable buffer will be dumped into a file. */
904 const char* name);
905
/**
 * Allocate memory of the requested size, aligned according to the given alignment.
 * memory: output receiving the allocated buffer.
 * flags: presumably a combination of libxsmm_malloc_flags -- confirm.
 * Returns an int status (presumably EXIT_SUCCESS on success -- confirm).
 */
LIBXSMM_API_INTERN int libxsmm_xmalloc(void** memory, size_t size, size_t alignment, int flags,
  /* The extra information is stored along with the allocated chunk; can be NULL/zero. */
  const void* extra, size_t extra_size);
/**
 * Release memory that was allocated using one of the libxsmm_*malloc functions.
 * NOTE(review): the meaning of "check" is not visible here; presumably it
 * enables validation of the given pointer -- confirm in the implementation.
 */
LIBXSMM_API_INTERN void libxsmm_xfree(const void* memory, int check);
912
/** Like libxsmm_release_scratch, but takes a lock (can be NULL). */
LIBXSMM_API_INTERN void libxsmm_xrelease_scratch(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK)* lock);
915
916 /**
917 * Format for instance an amount of Bytes like libxsmm_format_size(result, sizeof(result), nbytes, "KMGT", "B", 10).
918 * The value returned is in requested/determined unit so that the user can decide about printing the buffer.
919 */
920 LIBXSMM_API_INTERN size_t libxsmm_format_size(char buffer[32], int buffer_size, size_t nbytes, const char scale[], const char* unit, int base);
921
/**
 * Returns the type-name of the given data-type (can also be a libxsmm_gemm_precision).
 * NOTE(review): the result is presumably a static string that must not be freed.
 */
LIBXSMM_API_INTERN const char* libxsmm_typename(libxsmm_datatype datatype);
924
/**
 * Determines the given value in double-precision based on the given type.
 * value: points to a value of type datatype.
 * dvalue: output receiving the converted value.
 * Returns an int status (presumably EXIT_SUCCESS if datatype is supported -- confirm).
 */
LIBXSMM_API_INTERN int libxsmm_dvalue(libxsmm_datatype datatype, const void* value, double* dvalue);
927
/**
 * Services a build request, and (optionally) registers the code.
 * regindex: slot in the code registry; use LIBXSMM_CAPACITY_REGISTRY for unmanaged code.
 * code: output receiving the generated code (presumably -- confirm).
 * Returns an int status (presumably EXIT_SUCCESS on success -- confirm).
 */
LIBXSMM_API_INTERN int libxsmm_build(const libxsmm_build_request* request, unsigned int regindex, libxsmm_code_pointer* code);
930
/**
 * Returns the type-size of the given data-type (can also be a libxsmm_gemm_precision);
 * presumably in Bytes. Note: declared as LIBXSMM_API, i.e., not internal-only.
 */
LIBXSMM_API unsigned char libxsmm_typesize(libxsmm_datatype datatype);
933
/** Extended information about JIT-generated code (see libxsmm_get_kernel_xinfo). */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_kernel_xinfo {
  /** Non-zero if the kernel is registered. */
  unsigned int registered;
  /** Number of FLoating Point OPerationS (FLOPS). */
  unsigned int nflops;
} libxsmm_kernel_xinfo;
940
/**
 * Receive information about JIT-generated code.
 * desc, code_size: outputs receiving the kernel's descriptor and code size
 *   (presumably optional, i.e., may be NULL -- confirm).
 * Returns a pointer to the kernel's extended information (presumably NULL if
 * the code is unknown -- confirm).
 */
LIBXSMM_API_INTERN const libxsmm_kernel_xinfo* libxsmm_get_kernel_xinfo(libxsmm_code_pointer code, const libxsmm_descriptor** desc, size_t* code_size);
943
/** Calculates the duration in seconds between two ticks of the real-time clock (tick0 to tick1). */
LIBXSMM_API_INTERN double libxsmm_timer_duration_rtc(libxsmm_timer_tickint tick0, libxsmm_timer_tickint tick1);
/** Returns the current tick of the platform-specific real-time clock. */
LIBXSMM_API_INTERN libxsmm_timer_tickint libxsmm_timer_tick_rtc(void);
/** Returns the current tick of a (monotonic) platform-specific counter (presumably the TSC). */
LIBXSMM_API_INTERN libxsmm_timer_tickint libxsmm_timer_tick_tsc(void);
950
/** Initialization/finalization of the memory module (target_arch: presumably the LIBXSMM target-architecture id). */
LIBXSMM_API_INTERN void libxsmm_memory_init(int target_arch);
LIBXSMM_API_INTERN void libxsmm_memory_finalize(void);

/** Initialization/finalization of the DNN module (target_arch: presumably the LIBXSMM target-architecture id). */
LIBXSMM_API_INTERN void libxsmm_dnn_init(int target_arch);
LIBXSMM_API_INTERN void libxsmm_dnn_finalize(void);
956
/**
 * Internal function to calculate blockings; it is private API, hence it is
 * declared in this (internal) header.
 * C, K: presumably the input and output feature-map counts -- confirm.
 * C_block, K_block, fm_lp_block: outputs receiving the chosen blocking factors.
 * datatype_in, datatype_out: input/output data-types that influence the blocking.
 * Returns a libxsmm_dnn_err_t status.
 */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_get_feature_map_blocks(
  int C, int K, int* C_block, int* K_block, int* fm_lp_block,
  libxsmm_dnn_datatype datatype_in, libxsmm_dnn_datatype datatype_out);
961
/** Global lock; create a separate lock for an independent domain. */
LIBXSMM_APIVAR_PUBLIC(LIBXSMM_LOCK_TYPE(LIBXSMM_LOCK) libxsmm_lock_global);
/** Determines whether a threaded implementation is synchronized or not. */
LIBXSMM_APIVAR_PUBLIC(int libxsmm_nosync);
966
/** Function used to allocate default memory. */
LIBXSMM_APIVAR_PRIVATE(libxsmm_malloc_function libxsmm_default_malloc_fn);
/** Function used to allocate scratch memory. */
LIBXSMM_APIVAR_PRIVATE(libxsmm_malloc_function libxsmm_scratch_malloc_fn);
/** Function used to release default memory. */
LIBXSMM_APIVAR_PRIVATE(libxsmm_free_function libxsmm_default_free_fn);
/** Function used to release scratch memory. */
LIBXSMM_APIVAR_PRIVATE(libxsmm_free_function libxsmm_scratch_free_fn);
/** If non-NULL, this context is used by the context-form of (default) memory allocation. */
LIBXSMM_APIVAR_PRIVATE(const void* libxsmm_default_allocator_context);
/** If non-NULL, this context is used by the context-form of (scratch) memory allocation. */
LIBXSMM_APIVAR_PRIVATE(const void* libxsmm_scratch_allocator_context);
/** Number of scratch memory pools used; clamped against an internal maximum. */
LIBXSMM_APIVAR_PRIVATE(unsigned int libxsmm_scratch_pools);
/** Growth factor used to scale the scratch memory in case of reallocation. */
LIBXSMM_APIVAR_PRIVATE(double libxsmm_scratch_scale);
/** Number of seconds per RDTSC-cycle (zero or negative if RDTSC is invalid). */
LIBXSMM_APIVAR_PRIVATE(double libxsmm_timer_scale);
/** Counts the number of attempts to create an SPMDM-handle. */
LIBXSMM_APIVAR_PRIVATE(unsigned int libxsmm_statistic_num_spmdm);
/** Counts the maximum number of threads that have been active. */
LIBXSMM_APIVAR_PRIVATE(unsigned int libxsmm_thread_count);
989
#if (0 != LIBXSMM_SYNC)
/** Thread-local storage key; only declared if synchronization support (LIBXSMM_SYNC) is enabled. */
LIBXSMM_APIVAR_PRIVATE(LIBXSMM_TLS_TYPE libxsmm_tlskey);
#endif
993
994 #endif /*LIBXSMM_MAIN_H*/
995
996