1 /****************************************************************************** 2 * Copyright (c) Intel Corporation - All rights reserved. * 3 * This file is part of the LIBXSMM library. * 4 * * 5 * For information on the license, see the LICENSE file. * 6 * Further information: https://github.com/hfp/libxsmm/ * 7 * SPDX-License-Identifier: BSD-3-Clause * 8 ******************************************************************************/ 9 /* Alexander Heinecke (Intel Corp.) 10 ******************************************************************************/ 11 #ifndef LIBXSMM_DNN_TENSOR_H 12 #define LIBXSMM_DNN_TENSOR_H 13 14 #include "libxsmm_typedefs.h" 15 #include "libxsmm_dnn.h" 16 17 /** Opaque handles which represents convolutions and LIBXSMM datatypes */ 18 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor libxsmm_dnn_tensor; 19 20 typedef enum libxsmm_dnn_tensor_dimtype { 21 /** Mini-batch */ 22 LIBXSMM_DNN_TENSOR_DIMTYPE_N, 23 /** Image Height */ 24 LIBXSMM_DNN_TENSOR_DIMTYPE_H, 25 /** Image Width */ 26 LIBXSMM_DNN_TENSOR_DIMTYPE_W, 27 /** channels or input channels */ 28 LIBXSMM_DNN_TENSOR_DIMTYPE_C, 29 /** output channels */ 30 LIBXSMM_DNN_TENSOR_DIMTYPE_K, 31 /** kernel height */ 32 LIBXSMM_DNN_TENSOR_DIMTYPE_R, 33 /** kernel width */ 34 LIBXSMM_DNN_TENSOR_DIMTYPE_S, 35 /** sequence lenth counter */ 36 LIBXSMM_DNN_TENSOR_DIMTYPE_T, 37 /** channle group counter */ 38 LIBXSMM_DNN_TENSOR_DIMTYPE_G, 39 /** general counter */ 40 LIBXSMM_DNN_TENSOR_DIMTYPE_X 41 } libxsmm_dnn_tensor_dimtype; 42 43 /** types of different buffers */ 44 typedef enum libxsmm_dnn_tensor_type { 45 /** regular input buffer */ 46 LIBXSMM_DNN_REGULAR_INPUT, 47 /** regular input buffer */ 48 LIBXSMM_DNN_REGULAR_INPUT_ADD, 49 /** regular input buffer, transpose */ 50 LIBXSMM_DNN_REGULAR_INPUT_TRANS, 51 /** gradient input buffer */ 52 LIBXSMM_DNN_GRADIENT_INPUT, 53 /** gradient input buffer */ 54 LIBXSMM_DNN_GRADIENT_INPUT_ADD, 55 /** regular output buffer */ 56 LIBXSMM_DNN_REGULAR_OUTPUT, 57 /** gradient output buffer */ 58 LIBXSMM_DNN_GRADIENT_OUTPUT, 59 /** general input type */ 60 LIBXSMM_DNN_INPUT, 61 /** general output type */ 62 LIBXSMM_DNN_OUTPUT, 63 /** general activation type */ 64 LIBXSMM_DNN_ACTIVATION, 65 /* regular filter */ 66 LIBXSMM_DNN_REGULAR_FILTER, 67 /* regular filter */ 68 LIBXSMM_DNN_REGULAR_FILTER_TRANS, 69 /* gradient filter */ 70 LIBXSMM_DNN_GRADIENT_FILTER, 71 /* master filter */ 72 LIBXSMM_DNN_MASTER_FILTER, 73 /** general filter type */ 74 LIBXSMM_DNN_FILTER, 75 /* regular bias */ 76 LIBXSMM_DNN_REGULAR_CHANNEL_BIAS, 77 /* gradient bias */ 78 LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS, 79 /* bias */ 80 LIBXSMM_DNN_CHANNEL_BIAS, 81 /* regular beta */ 82 LIBXSMM_DNN_REGULAR_CHANNEL_BETA, 83 /* gradient beta */ 84 LIBXSMM_DNN_GRADIENT_CHANNEL_BETA, 85 /* beta */ 86 LIBXSMM_DNN_CHANNEL_BETA, 87 /* regular gamma */ 88 LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA, 89 /* gradient gamma */ 90 LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA, 91 /* Gamma */ 92 LIBXSMM_DNN_CHANNEL_GAMMA, 93 /* regular beta */ 94 LIBXSMM_DNN_CHANNEL_EXPECTVAL, 95 /* regular beta */ 96 LIBXSMM_DNN_CHANNEL_RCPSTDDEV, 97 /* variance */ 98 LIBXSMM_DNN_CHANNEL_VARIANCE, 99 /** general bias type */ 100 LIBXSMM_DNN_CHANNEL_SCALAR, 101 /** Labels */ 102 LIBXSMM_DNN_LABEL, 103 /** batch stats */ 104 LIBXSMM_DNN_BATCH_STATS, 105 LIBXSMM_DNN_MAX_STATS_FWD, 106 LIBXSMM_DNN_MAX_STATS_BWD, 107 LIBXSMM_DNN_MAX_STATS_UPD, 108 /** pooling mask */ 109 LIBXSMM_DNN_POOLING_MASK, 110 /** ReLU mask */ 111 LIBXSMM_DNN_RELU_MASK, 112 /** general type, if needed might cause API issues in copy in/out API */ 113 LIBXSMM_DNN_TENSOR, 114 115 /** regular input buffer */ 116 LIBXSMM_DNN_RNN_REGULAR_INPUT, 117 /** regular previous cell state buffer */ 118 LIBXSMM_DNN_RNN_REGULAR_CS_PREV, 119 /** regular previous hidden state buffer */ 120 LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV, 121 /** regular weight (LSTM: wi, wc, wf, wo) */ 122 LIBXSMM_DNN_RNN_REGULAR_WEIGHT, 123 /** regular recurrent weight (LSTM: ri, rc, rf, ro) */ 124 LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT, 125 /** regular weight (LSTM: wi, wc, wf, wo) */ 126 LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS, 127 /** regular recurrent weight (LSTM: ri, rc, rf, ro) */ 128 LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS, 129 /** regular bias (LSTM: bi, bc, bf, bo) */ 130 LIBXSMM_DNN_RNN_REGULAR_BIAS, 131 /** regular output cell state buffer */ 132 LIBXSMM_DNN_RNN_REGULAR_CS, 133 /** regular hidden state buffer */ 134 LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE, 135 /** gradient input buffer */ 136 LIBXSMM_DNN_RNN_GRADIENT_INPUT, 137 /** gradient previous cell state buffer */ 138 LIBXSMM_DNN_RNN_GRADIENT_CS_PREV, 139 /** gradient previous hidden state buffer */ 140 LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV, 141 /** gradient weight */ 142 LIBXSMM_DNN_RNN_GRADIENT_WEIGHT, 143 /** gradient recurrent weight */ 144 LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT, 145 /** gradient bias */ 146 LIBXSMM_DNN_RNN_GRADIENT_BIAS, 147 /** gradient output cell state buffer */ 148 LIBXSMM_DNN_RNN_GRADIENT_CS, 149 /** gradient hidden state buffer */ 150 LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE, 151 /** internal i buffer */ 152 LIBXSMM_DNN_RNN_INTERNAL_I, 153 /** internal f buffer */ 154 LIBXSMM_DNN_RNN_INTERNAL_F, 155 /** internal o buffer */ 156 LIBXSMM_DNN_RNN_INTERNAL_O, 157 /** internal ci buffer */ 158 LIBXSMM_DNN_RNN_INTERNAL_CI, 159 /** internal co buffer */ 160 LIBXSMM_DNN_RNN_INTERNAL_CO 161 } libxsmm_dnn_tensor_type; 162 163 /** layout descriptor to allow external data handling 164 outside of LIBXSMM */ 165 LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor_datalayout { 166 libxsmm_dnn_tensor_dimtype* dim_type; 167 unsigned int* dim_size; 168 unsigned int num_dims; 169 libxsmm_dnn_tensor_format format; /* format of activation buffer */ 170 libxsmm_dnn_datatype datatype; /* data type */ 171 libxsmm_dnn_tensor_type tensor_type; /* tensor type */ 172 } libxsmm_dnn_tensor_datalayout; 173 174 /** tensorlayout handling */ 175 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status); 176 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout); 177 LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status); 178 LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status); 179 LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status); 180 181 /** Create and manage buffers, filters and bias (non-NULL if successful) */ 182 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status); 183 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char exp, libxsmm_dnn_err_t* status); 184 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data); 185 LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status); 186 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status); 187 LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status); 188 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf); 189 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor); 190 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor); 191 192 /** 193 * Copy-in/out from a plain format such [N][C][H][W] or [N][H][W][C] 194 */ 195 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format); 196 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format); 197 198 #endif /*LIBXSMM_DNN_TENSOR_H*/ 199 200