1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst, Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 #include "libxsmm_main.h"
13 #include "libxsmm_dnn_tensor.h"
14 
15 #if defined(LIBXSMM_OFFLOAD_TARGET)
16 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
17 #endif
18 #include <math.h>
19 #if defined(_OPENMP)
20 # include <omp.h>
21 #endif
22 #if defined(LIBXSMM_OFFLOAD_TARGET)
23 # pragma offload_attribute(pop)
24 #endif
25 
26 
libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout * layout,const void * data,libxsmm_dnn_err_t * status)27 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status)
28 {
29   return libxsmm_dnn_link_qtensor(layout, data, 0, status);
30 }
31 
32 
libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout * layout,const void * data,const unsigned char scf,libxsmm_dnn_err_t * status)33 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char scf, libxsmm_dnn_err_t* status)
34 {
35   libxsmm_dnn_tensor* tensor = (libxsmm_dnn_tensor*)malloc(sizeof(libxsmm_dnn_tensor));
36   *status = LIBXSMM_DNN_SUCCESS;
37 
38   if (layout != 0 && tensor != 0 && data != 0) {
39     memset(tensor, 0, sizeof(libxsmm_dnn_tensor));
40     tensor->layout = libxsmm_dnn_duplicate_tensor_datalayout(layout, status);
41     tensor->data = (void*)data;
42     tensor->scf = scf;
43     /* when layout copy failed, free layout */
44     if (*status != LIBXSMM_DNN_SUCCESS) {
45       libxsmm_dnn_destroy_tensor_datalayout(tensor->layout);
46     }
47   } else {
48     *status = LIBXSMM_DNN_ERR_CREATE_TENSOR;
49   }
50 
51   if (*status != LIBXSMM_DNN_SUCCESS) {
52     free((libxsmm_dnn_tensor*)tensor);
53     tensor = 0;
54   }
55 
56   return tensor;
57 }
58 
59 
libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout * layout,libxsmm_dnn_err_t * status)60 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
61   libxsmm_dnn_tensor_datalayout* dst_layout;
62 
63   *status = LIBXSMM_DNN_SUCCESS;
64   dst_layout = 0;
65 
66   if (layout != 0 && layout->num_dims != 0) {
67     unsigned int dim = 0;
68 
69     dst_layout = (libxsmm_dnn_tensor_datalayout*)malloc(sizeof(libxsmm_dnn_tensor_datalayout));
70     if (0 != dst_layout) {
71       memset(dst_layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
72       dst_layout->dim_type = (libxsmm_dnn_tensor_dimtype*)malloc(layout->num_dims * sizeof(libxsmm_dnn_tensor_dimtype));
73       dst_layout->dim_size = (unsigned int*)malloc(layout->num_dims * sizeof(unsigned int));
74       dst_layout->num_dims = layout->num_dims;
75       dst_layout->format = layout->format;
76       dst_layout->datatype = layout->datatype;
77       dst_layout->tensor_type = layout->tensor_type;
78       if (0 != dst_layout->dim_type && 0 != dst_layout->dim_size) {
79         for (dim = 0; dim < layout->num_dims; ++dim) {
80           dst_layout->dim_type[dim] = layout->dim_type[dim];
81           dst_layout->dim_size[dim] = layout->dim_size[dim];
82         }
83       } else {
84         *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
85       }
86     } else {
87       *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
88     }
89   } else {
90     *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
91   }
92 
93   return dst_layout;
94 }
95 
96 
libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout * layout_a,const libxsmm_dnn_tensor_datalayout * layout_b,libxsmm_dnn_err_t * status)97 LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status) {
98   unsigned int result = 0;
99   *status = LIBXSMM_DNN_SUCCESS;
100 
101   if (layout_a != 0 && layout_b != 0) {
102     unsigned int dim = 0;
103 
104     if (layout_a->num_dims      != layout_b->num_dims)      { result = 1; }
105     if (layout_a->format        != layout_b->format)        { result = 1; }
106     if (layout_a->datatype      != layout_b->datatype)      { result = 1; }
107 
108     if (result == 0) {
109       for ( dim = 0; dim < layout_a->num_dims; ++dim ) {
110         if ( layout_a->dim_type[dim] != layout_b->dim_type[dim] ) { result = 1; }
111         if ( layout_a->dim_size[dim] != layout_b->dim_size[dim] ) { result = 1; }
112       }
113     }
114   } else {
115     *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
116     result = 100;
117   }
118 
119   return result;
120 }
121 
122 
libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout * layout)123 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout) {
124   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
125 
126   if (0 != layout) {
127     free(layout->dim_type);
128     free(layout->dim_size);
129     free(layout);
130   }
131   else {
132     status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
133   }
134 
135   return status;
136 }
137 
138 
libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout * layout,libxsmm_dnn_err_t * status)139 LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
140   unsigned int size = 0;
141   *status = LIBXSMM_DNN_SUCCESS;
142 
143   if (0 != layout) {
144     unsigned int dim = 0;
145     size = (unsigned int)libxsmm_dnn_typesize(layout->datatype);
146     for (dim = 0; dim < layout->num_dims; ++dim) {
147       size *= layout->dim_size[dim];
148     }
149   }
150   else {
151     *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
152   }
153 
154   return size;
155 }
156 
157 
libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout * layout,libxsmm_dnn_err_t * status)158 LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status) {
159   unsigned int elements = 1;
160   *status = LIBXSMM_DNN_SUCCESS;
161 
162   if (0 != layout) {
163     unsigned int dim = 0;
164     for ( dim = 0; dim < layout->num_dims; ++dim ) {
165       elements *= layout->dim_size[dim];
166     }
167   } else {
168     *status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
169     elements = 0;
170   }
171 
172   return elements;
173 }
174 
175 
libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor * tensor,const void * data)176 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data) {
177   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
178 
179   if ((0 != tensor) && (0 != data)) {
180     if (0 != tensor->layout) {
181       if (0 < tensor->layout->num_dims) {
182         tensor->data = (void*)data;
183       } else {
184         status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
185       }
186     } else {
187       status = LIBXSMM_DNN_ERR_INVALID_LAYOUT;
188     }
189   }
190   else {
191     status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
192   }
193 
194   return status;
195 }
196 
197 
libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor * tensor,libxsmm_dnn_err_t * status)198 LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
199 {
200   *status = LIBXSMM_DNN_SUCCESS;
201 
202   if (0 != tensor) {
203     return tensor->data;
204   }
205   else {
206     *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
207   }
208 
209   return 0;
210 }
211 
212 
libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor * tensor,libxsmm_dnn_err_t * status)213 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status) {
214   libxsmm_dnn_tensor_datalayout* dst_layout = NULL;
215   *status = LIBXSMM_DNN_SUCCESS;
216 
217   if (0 != tensor) {
218     dst_layout = libxsmm_dnn_duplicate_tensor_datalayout( tensor->layout, status );
219   }
220   else {
221     *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
222   }
223 
224   return dst_layout;
225 }
226 
227 
libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor * tensor,libxsmm_dnn_err_t * status)228 LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status)
229 {
230   *status = LIBXSMM_DNN_SUCCESS;
231 
232   if (0 != tensor) {
233     return tensor->scf;
234   }
235   else {
236     *status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
237   }
238 
239   return 0;
240 }
241 
242 
libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor * tensor,const unsigned char scf)243 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf)
244 {
245   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
246 
247   if (0 != tensor) {
248     tensor->scf = scf;
249   }
250   else {
251     status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
252   }
253 
254   return status;
255 }
256 
257 
libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor * tensor)258 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor)
259 {
260   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
261 
262   if (0 != tensor) { /* it is not an error attempting to destroy a NULL-handle */
263     /* free layout information stored in tensor */
264     if (0 != tensor->layout) {
265       libxsmm_dnn_destroy_tensor_datalayout( (libxsmm_dnn_tensor_datalayout*)tensor->layout );
266     }
267     /* deallocate handle structure */
268     free(/*remove constness*/(libxsmm_dnn_tensor*)tensor);
269   }
270 #if 0 /* releasing a NULL-buffer should be not an error (similar to freeing a NULL pointer) */
271   else {
272     status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
273   }
274 #endif
275   return status;
276 }
277 
278 
libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor * tensor,const void * data,const libxsmm_dnn_tensor_format in_format)279 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format)
280 {
281   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
282 
283   /* @TODO check for valid combination */
284 
285   if (0 != tensor) {
286     switch (tensor->layout->tensor_type) {
287       case LIBXSMM_DNN_REGULAR_INPUT:
288       case LIBXSMM_DNN_GRADIENT_INPUT:
289       case LIBXSMM_DNN_REGULAR_OUTPUT:
290       case LIBXSMM_DNN_GRADIENT_OUTPUT:
291       case LIBXSMM_DNN_INPUT:
292       case LIBXSMM_DNN_OUTPUT:
293       case LIBXSMM_DNN_ACTIVATION: {
294                                      switch (in_format) {
295                                        case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
296                                                                               if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
297                                                                                 switch (tensor->layout->datatype) {
298                                                                                   case LIBXSMM_DNN_DATATYPE_F32: {
299                                                                                                                    typedef float element_type;
300 #include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
301                                                                                                                  } break;
302                                                                                   case LIBXSMM_DNN_DATATYPE_BF16: {
303                                                                                                                    typedef libxsmm_bfloat16 element_type;
304 #define LIBXSMM_DNN_COPY_LOW_PRECISION
305 #include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
306 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
307                                                                                                                  } break;
308                                                                                   case LIBXSMM_DNN_DATATYPE_I32: {
309                                                                                                                    typedef int element_type;
310 #include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
311                                                                                                                  } break;
312                                                                                   case LIBXSMM_DNN_DATATYPE_I16: {
313                                                                                                                    typedef short  element_type;
314 #define LIBXSMM_DNN_COPY_LOW_PRECISION
315 #include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
316 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
317                                                                                                                  } break;
318                                                                                   case LIBXSMM_DNN_DATATYPE_I8: {
319                                                                                                                   typedef unsigned char element_type;
320 #define LIBXSMM_DNN_COPY_LOW_PRECISION
321 #include "template/libxsmm_dnn_tensor_buffer_copy_in_nchw.tpl.c"
322 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
323                                                                                                                 } break;
324                                                                                   default: {
325                                                                                              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
326                                                                                            }
327                                                                                 }
328                                                                               } else {
329                                                                                 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
330                                                                               }
331                                                                             } break;
332                                        default: {
333                                                   status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
334                                                 }
335                                      }
336                                    } break;
337       case LIBXSMM_DNN_REGULAR_FILTER:
338       case LIBXSMM_DNN_GRADIENT_FILTER:
339       case LIBXSMM_DNN_FILTER: {
340                                  switch (in_format) {
341                                    case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
342                                                                           if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
343                                                                             switch (tensor->layout->datatype) {
344                                                                               case LIBXSMM_DNN_DATATYPE_F32: {
345                                                                                                                typedef float element_type;
346 #include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
347                                                                                                              } break;
348                                                                               case LIBXSMM_DNN_DATATYPE_BF16: {
349                                                                                                                typedef libxsmm_bfloat16 element_type;
350 #include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
351                                                                                                              } break;
352                                                                               case LIBXSMM_DNN_DATATYPE_I16: {
353                                                                                                                typedef short element_type;
354 #include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
355                                                                                                              } break;
356                                                                               case LIBXSMM_DNN_DATATYPE_I8: {
357                                                                                                               typedef char element_type;
358 #include "template/libxsmm_dnn_tensor_filter_copy_in_kcrs.tpl.c"
359                                                                                                             } break;
360                                                                               default: {
361                                                                                          status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
362                                                                                        }
363                                                                             }
364                                                                           } else {
365                                                                             status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
366                                                                           }
367                                                                         } break;
368                                    default: {
369                                               status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
370                                             }
371                                  }
372                                } break;
373       case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
374       case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
375       case LIBXSMM_DNN_CHANNEL_BIAS:
376       case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
377       case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
378       case LIBXSMM_DNN_CHANNEL_BETA:
379       case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
380       case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
381       case LIBXSMM_DNN_CHANNEL_GAMMA:
382       case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
383       case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
384       case LIBXSMM_DNN_CHANNEL_VARIANCE:
385       case LIBXSMM_DNN_CHANNEL_SCALAR: {
386                                switch (in_format) {
387                                  case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
388                                                                         if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
389                                                                           switch (tensor->layout->datatype) {
390                                                                             case LIBXSMM_DNN_DATATYPE_F32: {
391                                                                                                              typedef float element_type;
392 #include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
393                                                                                                            } break;
394                                                                             case LIBXSMM_DNN_DATATYPE_BF16: {
395                                                                                                              typedef libxsmm_bfloat16 element_type;
396 #define LIBXSMM_DNN_COPY_LOW_PRECISION
397 #include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
398 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
399                                                                                                            } break;
400                                                                             case LIBXSMM_DNN_DATATYPE_I16: {
401                                                                                                              typedef short element_type;
402 #define LIBXSMM_DNN_COPY_LOW_PRECISION
403 #include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
404 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
405                                                                                                            } break;
406                                                                             case LIBXSMM_DNN_DATATYPE_I8: {
407                                                                                                             typedef char element_type;
408 #define LIBXSMM_DNN_COPY_LOW_PRECISION
409 #include "template/libxsmm_dnn_tensor_bias_copy_in_nchw.tpl.c"
410 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
411                                                                                                           } break;
412                                                                             default: {
413                                                                                        status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
414                                                                                      }
415                                                                           }
416                                                                         } else {
417                                                                           status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
418                                                                         }
419                                                                       } break;
420                                  default: {
421                                             status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
422                                           }
423                                }
424                              } break;
425       default: {
426                  status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
427                }
428     }
429   }
430   else {
431     status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
432   }
433 
434   return status;
435 }
436 
437 
libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor * tensor)438 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor)
439 {
440   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
441 
442   if (0 != tensor) {
443     const size_t size = libxsmm_dnn_get_tensor_elements( tensor->layout, &status );
444     size_t i;
445     /* use for-loops to potentially leverage NUMA in the future */
446     switch (tensor->layout->datatype) {
447       case LIBXSMM_DNN_DATATYPE_F32: {
448                                        float* fp32_data = (float*)tensor->data;
449                                        for (i = 0; i < size; ++i) fp32_data[i] = 0.0f;
450                                      } break;
451       case LIBXSMM_DNN_DATATYPE_BF16: {
452                                        libxsmm_bfloat16* bfp16_data = (libxsmm_bfloat16*)tensor->data;
453                                        for (i = 0; i < size; ++i) bfp16_data[i] = 0;
454                                      } break;
455       case LIBXSMM_DNN_DATATYPE_I32: {
456                                        int* int32_data = (int*)tensor->data;
457                                        for (i = 0; i < size; ++i) int32_data[i] = 0;
458                                      } break;
459       case LIBXSMM_DNN_DATATYPE_I16: {
460                                        short* int16_data = (short*)tensor->data;
461                                        for (i = 0; i < size; ++i) int16_data[i] = 0;
462                                      } break;
463       case LIBXSMM_DNN_DATATYPE_I8: {
464                                       char* int8_data = (char*)tensor->data;
465                                       for (i = 0; i < size; ++i) int8_data[i] = 0;
466                                     } break;
467       default: {
468         status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
469       }
470     }
471   }
472   else {
473     status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
474   }
475 
476   return status;
477 }
478 
479 
libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor * tensor,void * data,const libxsmm_dnn_tensor_format out_format)480 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format)
481 {
482   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
483 
484   /* @TODO check for valid combination */
485 
486   if (0 != tensor) {
487     switch (tensor->layout->tensor_type) {
488       case LIBXSMM_DNN_REGULAR_INPUT:
489       case LIBXSMM_DNN_GRADIENT_INPUT:
490       case LIBXSMM_DNN_REGULAR_OUTPUT:
491       case LIBXSMM_DNN_GRADIENT_OUTPUT:
492       case LIBXSMM_DNN_INPUT:
493       case LIBXSMM_DNN_OUTPUT:
494       case LIBXSMM_DNN_ACTIVATION: {
495                                      switch (out_format) {
496                                        case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
497                                                                               if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
498                                                                                 switch (tensor->layout->datatype) {
499                                                                                   case LIBXSMM_DNN_DATATYPE_F32: {
500                                                                                                                    typedef float element_type;
501 #include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
502                                                                                                                  } break;
503                                                                                   case LIBXSMM_DNN_DATATYPE_BF16: {
504                                                                                                                    typedef libxsmm_bfloat16 element_type;
505 #define LIBXSMM_DNN_COPY_LOW_PRECISION
506 #include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
507 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
508                                                                                                                  } break;
509                                                                                   case LIBXSMM_DNN_DATATYPE_I32: {
510                                                                                                                    typedef int element_type;
511 #include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
512                                                                                                                  } break;
513                                                                                   case LIBXSMM_DNN_DATATYPE_I16: {
514                                                                                                                    typedef short element_type;
515 #define LIBXSMM_DNN_COPY_LOW_PRECISION
516 #include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
517 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
518                                                                                                                  } break;
519                                                                                   case LIBXSMM_DNN_DATATYPE_I8: {
520                                                                                                                   typedef unsigned char element_type;
521 #define LIBXSMM_DNN_COPY_LOW_PRECISION
522 #include "template/libxsmm_dnn_tensor_buffer_copy_out_nchw.tpl.c"
523 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
524                                                                                                                 } break;
525                                                                                   default: {
526                                                                                              status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
527                                                                                            }
528                                                                                 }
529                                                                               } else {
530                                                                                 status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
531                                                                               }
532                                                                             } break;
533                                        default: {
534                                                   status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
535                                                 }
536                                      }
537                                    } break;
538       case LIBXSMM_DNN_REGULAR_FILTER:
539       case LIBXSMM_DNN_GRADIENT_FILTER:
540       case LIBXSMM_DNN_FILTER: {
541                                  switch (out_format) {
542                                    case LIBXSMM_DNN_TENSOR_FORMAT_KCRS: {
543                                                                           if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
544                                                                             switch (tensor->layout->datatype) {
545                                                                               case LIBXSMM_DNN_DATATYPE_F32: {
546                                                                                                                typedef float element_type;
547 #include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
548                                                                                                              } break;
549 
550                                                                               case LIBXSMM_DNN_DATATYPE_BF16: {
551                                                                                                                typedef libxsmm_bfloat16 element_type;
552 #include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
553                                                                                                              } break;
554                                                                                    case LIBXSMM_DNN_DATATYPE_I32: {
555                                                                                                                    typedef int element_type;
556 #include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
557                                                                                                                  } break;
558                                                                                    case LIBXSMM_DNN_DATATYPE_I16: {
559                                                                                                                typedef short  element_type;
560 #include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
561                                                                                                              } break;
562                                                                               case LIBXSMM_DNN_DATATYPE_I8: {
563                                                                                                               typedef char element_type;
564 #include "template/libxsmm_dnn_tensor_filter_copy_out_kcrs.tpl.c"
565                                                                                                             } break;
566                                                                               default: {
567                                                                                          status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
568                                                                                        }
569                                                                             }
570                                                                           } else {
571                                                                             status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
572                                                                           }
573                                                                         } break;
574                                    default: {
575                                               status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
576                                             }
577                                  }
578                                } break;
579       case LIBXSMM_DNN_REGULAR_CHANNEL_BIAS:
580       case LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS:
581       case LIBXSMM_DNN_CHANNEL_BIAS:
582       case LIBXSMM_DNN_REGULAR_CHANNEL_BETA:
583       case LIBXSMM_DNN_GRADIENT_CHANNEL_BETA:
584       case LIBXSMM_DNN_CHANNEL_BETA:
585       case LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA:
586       case LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA:
587       case LIBXSMM_DNN_CHANNEL_GAMMA:
588       case LIBXSMM_DNN_CHANNEL_EXPECTVAL:
589       case LIBXSMM_DNN_CHANNEL_RCPSTDDEV:
590       case LIBXSMM_DNN_CHANNEL_VARIANCE:
591       case LIBXSMM_DNN_CHANNEL_SCALAR: {
592                                switch (out_format) {
593                                  case LIBXSMM_DNN_TENSOR_FORMAT_NCHW: {
594                                                                         if ( (tensor->layout->format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0 ) {
595                                                                           switch (tensor->layout->datatype) {
596                                                                             case LIBXSMM_DNN_DATATYPE_F32: {
597                                                                                                              typedef float element_type;
598 #include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
599                                                                                                            } break;
600                                                                             case LIBXSMM_DNN_DATATYPE_BF16: {
601                                                                                                              typedef libxsmm_bfloat16 element_type;
602 #define LIBXSMM_DNN_COPY_LOW_PRECISION
603 #include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
604 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
605                                                                                                            } break;
606                                                                             case LIBXSMM_DNN_DATATYPE_I16: {
607                                                                                                              typedef short element_type;
608 #define LIBXSMM_DNN_COPY_LOW_PRECISION
609 #include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
610 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
611                                                                                                            } break;
612                                                                             case LIBXSMM_DNN_DATATYPE_I8: {
613                                                                                                             typedef char element_type;
614 #define LIBXSMM_DNN_COPY_LOW_PRECISION
615 #include "template/libxsmm_dnn_tensor_bias_copy_out_nchw.tpl.c"
616 #undef LIBXSMM_DNN_COPY_LOW_PRECISION
617                                                                                                           } break;
618                                                                             default: {
619                                                                                        status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
620                                                                                      }
621                                                                           }
622                                                                         } else {
623                                                                           status = LIBXSMM_DNN_ERR_UNSUPPORTED_SRC_FORMAT;
624                                                                         }
625                                                                       } break;
626                                  default: {
627                                             status = LIBXSMM_DNN_ERR_UNSUPPORTED_DST_FORMAT;
628                                           }
629                                }
630                              } break;
631       default: {
632                  status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
633                }
634     }
635   }
636   else {
637     status = LIBXSMM_DNN_ERR_INVALID_TENSOR;
638   }
639 
640   return status;
641 }
642 
643