1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Hans Pabst, Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 #include "libxsmm_main.h"
13 #include "libxsmm_dnn_tensor.h"
14
15 #if defined(LIBXSMM_OFFLOAD_TARGET)
16 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
17 #endif
18 #include <math.h>
19 #if defined(_OPENMP)
20 # include <omp.h>
21 #endif
22 #if defined(LIBXSMM_OFFLOAD_TARGET)
23 # pragma offload_attribute(pop)
24 #endif
25
26
libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout * layout,const void * data,libxsmm_dnn_err_t * status)27 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status)
28 {
29 return libxsmm_dnn_link_qtensor(layout, data, 0, status);
30 }
31
32
libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout * layout,const void * data,const unsigned char scf,libxsmm_dnn_err_t * status)33 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char scf, libxsmm_dnn_err_t* status)
34 {
35 libxsmm_dnn_tensor* tensor = (libxsmm_dnn_tensor*)malloc(sizeof(libxsmm_dnn_tensor));
36 *status = LIBXSMM_DNN_SUCCESS;
37
38 if (layout != 0 && tensor != 0 && data != 0) {
39 memset(tensor, 0, sizeof(libxsmm_dnn_tensor));
40 tensor->layout = libxsmm_dnn_duplicate_tensor_datalayout(layout, status);
41 tensor->data = (void*)data;
42 tensor->scf = scf;
43 /* when layout copy failed, free layout */
44 if (*status != LIBXSMM_DNN_SUCCESS) {
45 libxsmm_dnn_destroy_tensor_datalayout(tensor->layout);
46 }
47 } else {
48 *status = LIBXSMM_DNN_ERR_CREATE_TENSOR;
49 }
50
51 if (*status != LIBXSMM_DNN_SUCCESS) {
52 free((libxsmm_dnn_tensor*)tensor);
53 tensor = 0;
54 }
55
56 return tensor;
57 }
58
59
libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout * layout,libxsmm_dnn_err_t * status)60 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
61 libxsmm_dnn_tensor_datalayout* dst_layout;
62
63 *status = LIBXSMM_DNN_SUCCESS;
64 dst_layout = 0;
65
66 if (layout != 0 && layout->num_dims != 0) {
67 unsigned int dim = 0;
68
69 dst_layout = (libxsmm_dnn_tensor_datalayout*)malloc(sizeof(libxsmm_dnn_tensor_datalayout));
70 if (0 != dst_layout) {
71 memset(dst_layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
72 dst_layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(layout->num_dims * sizeof(libxsmm_dnn_tensor_dimtype));
73 dst_layout->dim_size = (unsigned int*)malloc(layout->num_dims * sizeof(unsigned int));
74 dst_layout->num_dims = layout->num_dims;
75 dst_layout->format = layout->format;
76 dst_layout->datatype = layout->datatype;
77 dst_layout->tensor_type = layout->tensor_type;
78 if (0 != dst_layout->dim_type && 0 != dst_layout->dim_size) {
79 for (dim = 0; dim < layout->num_dims; ++dim) {
80 dst_layout->dim_type[dim] = layout->dim_type[dim];
81 dst_layout->dim_size[dim] = layout->dim_size[dim];
82 }
83 } else {
84 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
85 }
86 } else {
87 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
88 }
89 } else {
90 *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
91 }
92
93 return dst_layout;
94 }
95
96
libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout * layout_a,const libxsmm_dnn_tensor_datalayout * layout_b,libxsmm_dnn_err_t * status)97 LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status) {
98 unsigned int result = 0;
99 *status = LIBXSMM_DNN_SUCCESS;
100
101 if (layout_a != 0 && layout_b != 0) {
102 unsigned int dim = 0;
103
104 if (layout_a->num_dims != layout_b->num_dims) { result = 1; }
105 if (layout_a->format != layout_b->format) { result = 1; }
106 if (layout_a->datatype != layout_b->datatype) { result = 1; }
107
108 if (result == 0) {
109 for ( dim = 0; dim < layout_a->num_dims; ++dim ) {
110 if ( layout_a->dim_type[dim] != layout_b->dim_type[dim] ) { result = 1; }
111 if ( layout_a->dim_size[dim] != layout_b->dim_size[dim] ) { result = 1; }
112 }
113 }
114 } else {
115 *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
116 result = 100;
117 }
118
119 return result;
120 }
121
122
libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout * layout)123 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout) {
124 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
125
126 if (0 != layout) {
127 free(layout->dim_type);
128 free(layout->dim_size);
129 free(layout);
130 }
131 else {
132 status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
133 }
134
135 return status;
136 }
137
138
libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout * layout,libxsmm_dnn_err_t * status)139 LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
140 unsigned int size = 0;
141 *status = LIBXSMM_DNN_SUCCESS;
142
143 if (0 != layout) {
144 unsigned int dim = 0;
145 size = (unsigned int)libxsmm_dnn_typesize(layout->datatype);
146 for (dim = 0; dim < layout->num_dims; ++dim) {
147 size *= layout->dim_size[dim];
148 }
149 }
150 else {
151 *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
152 }
153
154 return size;
155 }
156
157
libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout * layout,libxsmm_dnn_err_t * status)158 LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
159 unsigned int elements = 1;
160 *status = LIBXSMM_DNN_SUCCESS;
161
162 if (0 != layout) {
163 unsigned int dim = 0;
164 for ( dim = 0; dim < layout->num_dims; ++dim ) {
165 elements *= layout->dim_size[dim];
166 }
167 } else {
168 *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
169 elements = 0;
170 }
171
172 return elements;
173 }
174
175
libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor * tensor,const void * data)176 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data) {
177 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
178
179 if ((0 != tensor) && (0 != data)) {
180 if (0 != tensor->layout) {
181 if (0 < tensor->layout->num_dims) {
182 tensor->data = (void*)data;
183 } else {
184 status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
185 }
186 } else {
187 status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
188 }
189 }
190 else {
191 status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
192 }
193
194 return status;
195 }
196
197
libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor * tensor,libxsmm_dnn_err_t * status)198 LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
199 {
200 *status = LIBXSMM_DNN_SUCCESS;
201
202 if (0 != tensor) {
203 return tensor->data;
204 }
205 else {
206 *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
207 }
208
209 return 0;
210 }
211
212
libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor * tensor,libxsmm_dnn_err_t * status)213 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status) {
214 libxsmm_dnn_tensor_datalayout* dst_layout = NULL;
215 *status = LIBXSMM_DNN_SUCCESS;
216
217 if (0 != tensor) {
218 dst_layout = libxsmm_dnn_duplicate_tensor_datalayout( tensor->layout, status );
219 }
220 else {
221 *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
222 }
223
224 return dst_layout;
225 }
226
227
libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor * tensor,libxsmm_dnn_err_t * status)228 LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
229 {
230 *status = LIBXSMM_DNN_SUCCESS;
231
232 if (0 != tensor) {
233 return tensor->scf;
234 }
235 else {
236 *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
237 }
238
239 return 0;
240 }
241
242
libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor * tensor,const unsigned char scf)243 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf)
244 {
245 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
246
247 if (0 != tensor) {
248 tensor->scf = scf;
249 }
250 else {
251 status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
252 }
253
254 return status;
255 }
256
257
libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor * tensor)258 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor)
259 {
260 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
261
262 if (0 != tensor) { /* it is not an error attempting to destroy a NULL-handle */
263 /* free layout information stored in tensor */
264 if (0 != tensor->layout) {
265 libxsmm_dnn_destroy_tensor_datalayout( (libxsmm_dnn_tensor_datalayout*)tensor->layout );
266 }
267 /* deallocate handle structure */
268 free(/*remove constness*/(libxsmm_dnn_tensor*)tensor);
269 }
270 #if 0 /* releasing a NULL-buffer should be not an error (similar to freeing a NULL pointer) */
271 else {
272 status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
273 }
274 #endif
275 return status;
276 }
277
278
/* Copies user data (in a plain format such as NCHW/KCRS) INTO the tensor's
 * internal LIBXSMM format. The actual copy loops live in the included
 * template files, which expect `tensor`, `data`, and a local `element_type`
 * typedef to be in scope; LIBXSMM_DNN_COPY_LOW_PRECISION selects the
 * low-precision variant inside the template.
 * NOTE(review): tensor->layout is dereferenced without a NULL check — the
 * caller is presumably required to pass a fully linked tensor; confirm. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* @TODO check for valid combination */

  if (0 != tensor) {
    /* dispatch on what the tensor represents: activations, filters, or
     * per-channel vectors each have their own internal layout/template */
    switch (tensor->layout->tensor_type) {
      case LIBXSMM_DNN_REGULAR_INPUT:
      case LIBXSMM_DNN_GRADIENT_INPUT:
      case LIBXSMM_DNN_REGULAR_OUTPUT:
      case LIBXSMM_DNN_GRADIENT_OUTPUT:
      case LIBXSMM_DNN_INPUT:
      case LIBXSMM_DNN_OUTPUT:
      case LIBXSMM_DNN_ACTIVATION: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              /* pick the copy template instantiation by element datatype */
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I32: {
                  typedef int element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef unsigned char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      case LIBXSMM_DNN_REGULAR_FILTER:
      case LIBXSMM_DNN_GRADIENT_FILTER:
      case LIBXSMM_DNN_FILTER: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      /* per-channel vectors (bias/beta/gamma/statistics) share one template */
      case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
      case LIBXSMM_DNN_CHANNEL_BIAS:
      case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
      case LIBXSMM_DNN_CHANNEL_BETA:
      case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
      case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
      case LIBXSMM_DNN_CHANNEL_VARIANCE:
      case LIBXSMM_DNN_CHANNEL_SCALAR: {
        switch (in_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }

  return status;
}
436
437
libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor * tensor)438 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor)
439 {
440 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
441
442 if (0 != tensor) {
443 const size_t size = libxsmm_dnn_get_tensor_elements( tensor->layout, &status );
444 size_t i;
445 /* use for-loops to potentially leverage NUMA in the future */
446 switch (tensor->layout->datatype) {
447 case LIBXSMM_DNN_DATATYPE_F32: {
448 float* fp32_data = (float*)tensor->data;
449 for (i = 0; i < size; ++i) fp32_data[i] = 0.0f;
450 } break;
451 case LIBXSMM_DNN_DATATYPE_BF16: {
452 libxsmm_bfloat16* bfp16_data = (libxsmm_bfloat16*)tensor->data;
453 for (i = 0; i < size; ++i) bfp16_data[i] = 0;
454 } break;
455 case LIBXSMM_DNN_DATATYPE_I32: {
456 int* int32_data = (int*)tensor->data;
457 for (i = 0; i < size; ++i) int32_data[i] = 0;
458 } break;
459 case LIBXSMM_DNN_DATATYPE_I16: {
460 short* int16_data = (short*)tensor->data;
461 for (i = 0; i < size; ++i) int16_data[i] = 0;
462 } break;
463 case LIBXSMM_DNN_DATATYPE_I8: {
464 char* int8_data = (char*)tensor->data;
465 for (i = 0; i < size; ++i) int8_data[i] = 0;
466 } break;
467 default: {
468 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
469 }
470 }
471 }
472 else {
473 status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
474 }
475
476 return status;
477 }
478
479
/* Copies the tensor's internal LIBXSMM-format data OUT into a user buffer in
 * a plain format (NCHW for activations/channel vectors, KCRS for filters).
 * Mirrors libxsmm_dnn_copyin_tensor; the copy loops live in the included
 * template files, which expect `tensor`, `data`, and a local `element_type`
 * typedef in scope. Note the error roles are swapped relative to copy-in:
 * here the tensor is the SRC and the user format is the DST.
 * NOTE(review): tensor->layout is dereferenced without a NULL check — the
 * caller is presumably required to pass a fully linked tensor; confirm. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* @TODO check for valid combination */

  if (0 != tensor) {
    /* dispatch on what the tensor represents: activations, filters, or
     * per-channel vectors each have their own internal layout/template */
    switch (tensor->layout->tensor_type) {
      case LIBXSMM_DNN_REGULAR_INPUT:
      case LIBXSMM_DNN_GRADIENT_INPUT:
      case LIBXSMM_DNN_REGULAR_OUTPUT:
      case LIBXSMM_DNN_GRADIENT_OUTPUT:
      case LIBXSMM_DNN_INPUT:
      case LIBXSMM_DNN_OUTPUT:
      case LIBXSMM_DNN_ACTIVATION: {
        switch (out_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              /* pick the copy template instantiation by element datatype */
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I32: {
                  typedef int element_type;
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef unsigned char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
          }
        }
      } break;
      case LIBXSMM_DNN_REGULAR_FILTER:
      case LIBXSMM_DNN_GRADIENT_FILTER:
      case LIBXSMM_DNN_FILTER: {
        switch (out_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;

                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I32: {
                  typedef int element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
          }
        }
      } break;
      /* per-channel vectors (bias/beta/gamma/statistics) share one template */
      case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
      case LIBXSMM_DNN_CHANNEL_BIAS:
      case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
      case LIBXSMM_DNN_CHANNEL_BETA:
      case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
      case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_GAMMA:
      case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
      case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
      case LIBXSMM_DNN_CHANNEL_VARIANCE:
      case LIBXSMM_DNN_CHANNEL_SCALAR: {
        switch (out_format) {
          case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
            if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
              switch (tensor->layout->datatype) {
                case LIBXSMM_DNN_DATATYPE_F32: {
                  typedef float element_type;
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
                } break;
                case LIBXSMM_DNN_DATATYPE_BF16: {
                  typedef libxsmm_bfloat16 element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I16: {
                  typedef short element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                case LIBXSMM_DNN_DATATYPE_I8: {
                  typedef char element_type;
#define LIBXSMM_DNN_COPY_LOW_PRECISION
#include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
#undef LIBXSMM_DNN_COPY_LOW_PRECISION
                } break;
                default: {
                  status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
                }
              }
            } else {
              status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
            }
          } break;
          default: {
            status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
          }
        }
      } break;
      default: {
        status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
      }
    }
  }
  else {
    status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
  }

  return status;
}
642
643