1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_DNN_TENSOR_H
12 #define LIBXSMM_DNN_TENSOR_H
13 
14 #include "libxsmm_typedefs.h"
15 #include "libxsmm_dnn.h"
16 
/** Opaque handle representing a LIBXSMM DNN tensor. Its definition is not
    visible in this header; obtain one from libxsmm_dnn_link_tensor or
    libxsmm_dnn_link_qtensor and operate on it through the functions
    declared below. */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor libxsmm_dnn_tensor;
19 
/** Kind of a single tensor dimension, as listed in
    libxsmm_dnn_tensor_datalayout::dim_type.
    Enumerator values are implicit (0, 1, 2, ...); inserting or reordering
    entries changes the numeric value of every entry that follows. */
typedef enum libxsmm_dnn_tensor_dimtype {
  /** mini-batch */
  LIBXSMM_DNN_TENSOR_DIMTYPE_N,
  /** image height */
  LIBXSMM_DNN_TENSOR_DIMTYPE_H,
  /** image width */
  LIBXSMM_DNN_TENSOR_DIMTYPE_W,
  /** channels or input channels */
  LIBXSMM_DNN_TENSOR_DIMTYPE_C,
  /** output channels */
  LIBXSMM_DNN_TENSOR_DIMTYPE_K,
  /** kernel height */
  LIBXSMM_DNN_TENSOR_DIMTYPE_R,
  /** kernel width */
  LIBXSMM_DNN_TENSOR_DIMTYPE_S,
  /** sequence length counter */
  LIBXSMM_DNN_TENSOR_DIMTYPE_T,
  /** channel group counter */
  LIBXSMM_DNN_TENSOR_DIMTYPE_G,
  /** general counter */
  LIBXSMM_DNN_TENSOR_DIMTYPE_X
} libxsmm_dnn_tensor_dimtype;
42 
/** Role of a tensor, i.e. which kind of buffer it is.
    Enumerator values are implicit (0, 1, 2, ...); inserting or reordering
    entries changes the numeric value of every entry that follows. */
typedef enum libxsmm_dnn_tensor_type {
  /** regular input buffer */
  LIBXSMM_DNN_REGULAR_INPUT,
  /** regular input buffer, ADD variant (presumably the extra operand of a fused add - confirm) */
  LIBXSMM_DNN_REGULAR_INPUT_ADD,
  /** regular input buffer, transpose */
  LIBXSMM_DNN_REGULAR_INPUT_TRANS,
  /** gradient input buffer */
  LIBXSMM_DNN_GRADIENT_INPUT,
  /** gradient input buffer, ADD variant (gradient counterpart of LIBXSMM_DNN_REGULAR_INPUT_ADD) */
  LIBXSMM_DNN_GRADIENT_INPUT_ADD,
  /** regular output buffer */
  LIBXSMM_DNN_REGULAR_OUTPUT,
  /** gradient output buffer */
  LIBXSMM_DNN_GRADIENT_OUTPUT,
  /** general input type */
  LIBXSMM_DNN_INPUT,
  /** general output type */
  LIBXSMM_DNN_OUTPUT,
  /** general activation type */
  LIBXSMM_DNN_ACTIVATION,
  /** regular filter */
  LIBXSMM_DNN_REGULAR_FILTER,
  /** regular filter, transpose */
  LIBXSMM_DNN_REGULAR_FILTER_TRANS,
  /** gradient filter */
  LIBXSMM_DNN_GRADIENT_FILTER,
  /** master filter (presumably a master copy of the weights, e.g. in higher precision - confirm) */
  LIBXSMM_DNN_MASTER_FILTER,
  /** general filter type */
  LIBXSMM_DNN_FILTER,
  /** regular per-channel bias */
  LIBXSMM_DNN_REGULAR_CHANNEL_BIAS,
  /** gradient per-channel bias */
  LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS,
  /** general per-channel bias type */
  LIBXSMM_DNN_CHANNEL_BIAS,
  /** regular per-channel beta */
  LIBXSMM_DNN_REGULAR_CHANNEL_BETA,
  /** gradient per-channel beta */
  LIBXSMM_DNN_GRADIENT_CHANNEL_BETA,
  /** general per-channel beta type */
  LIBXSMM_DNN_CHANNEL_BETA,
  /** regular per-channel gamma */
  LIBXSMM_DNN_REGULAR_CHANNEL_GAMMA,
  /** gradient per-channel gamma */
  LIBXSMM_DNN_GRADIENT_CHANNEL_GAMMA,
  /** general per-channel gamma type */
  LIBXSMM_DNN_CHANNEL_GAMMA,
  /** per-channel expected value (mean) */
  LIBXSMM_DNN_CHANNEL_EXPECTVAL,
  /** per-channel reciprocal standard deviation */
  LIBXSMM_DNN_CHANNEL_RCPSTDDEV,
  /** per-channel variance */
  LIBXSMM_DNN_CHANNEL_VARIANCE,
  /** general per-channel scalar type */
  LIBXSMM_DNN_CHANNEL_SCALAR,
  /** labels */
  LIBXSMM_DNN_LABEL,
  /** batch statistics */
  LIBXSMM_DNN_BATCH_STATS,
  /** max statistics, FWD */
  LIBXSMM_DNN_MAX_STATS_FWD,
  /** max statistics, BWD */
  LIBXSMM_DNN_MAX_STATS_BWD,
  /** max statistics, UPD */
  LIBXSMM_DNN_MAX_STATS_UPD,
  /** pooling mask */
  LIBXSMM_DNN_POOLING_MASK,
  /** ReLU mask */
  LIBXSMM_DNN_RELU_MASK,
  /** general type, if needed; might cause API issues in the copy-in/out API */
  LIBXSMM_DNN_TENSOR,

  /** RNN regular input buffer */
  LIBXSMM_DNN_RNN_REGULAR_INPUT,
  /** RNN regular previous cell state buffer */
  LIBXSMM_DNN_RNN_REGULAR_CS_PREV,
  /** RNN regular previous hidden state buffer */
  LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE_PREV,
  /** RNN regular weight (LSTM: wi, wc, wf, wo) */
  LIBXSMM_DNN_RNN_REGULAR_WEIGHT,
  /** RNN regular recurrent weight (LSTM: ri, rc, rf, ro) */
  LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT,
  /** RNN regular weight, transpose (LSTM: wi, wc, wf, wo) */
  LIBXSMM_DNN_RNN_REGULAR_WEIGHT_TRANS,
  /** RNN regular recurrent weight, transpose (LSTM: ri, rc, rf, ro) */
  LIBXSMM_DNN_RNN_REGULAR_RECUR_WEIGHT_TRANS,
  /** RNN regular bias (LSTM: bi, bc, bf, bo) */
  LIBXSMM_DNN_RNN_REGULAR_BIAS,
  /** RNN regular output cell state buffer */
  LIBXSMM_DNN_RNN_REGULAR_CS,
  /** RNN regular hidden state buffer */
  LIBXSMM_DNN_RNN_REGULAR_HIDDEN_STATE,
  /** RNN gradient input buffer */
  LIBXSMM_DNN_RNN_GRADIENT_INPUT,
  /** RNN gradient previous cell state buffer */
  LIBXSMM_DNN_RNN_GRADIENT_CS_PREV,
  /** RNN gradient previous hidden state buffer */
  LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE_PREV,
  /** RNN gradient weight */
  LIBXSMM_DNN_RNN_GRADIENT_WEIGHT,
  /** RNN gradient recurrent weight */
  LIBXSMM_DNN_RNN_GRADIENT_RECUR_WEIGHT,
  /** RNN gradient bias */
  LIBXSMM_DNN_RNN_GRADIENT_BIAS,
  /** RNN gradient output cell state buffer */
  LIBXSMM_DNN_RNN_GRADIENT_CS,
  /** RNN gradient hidden state buffer */
  LIBXSMM_DNN_RNN_GRADIENT_HIDDEN_STATE,
  /** RNN internal i buffer */
  LIBXSMM_DNN_RNN_INTERNAL_I,
  /** RNN internal f buffer */
  LIBXSMM_DNN_RNN_INTERNAL_F,
  /** RNN internal o buffer */
  LIBXSMM_DNN_RNN_INTERNAL_O,
  /** RNN internal ci buffer */
  LIBXSMM_DNN_RNN_INTERNAL_CI,
  /** RNN internal co buffer */
  LIBXSMM_DNN_RNN_INTERNAL_CO
} libxsmm_dnn_tensor_type;
162 
/** Layout descriptor: describes a tensor's dimensions, data format,
    datatype and role so that its data can be handled outside of LIBXSMM.
    Layouts are returned by libxsmm_dnn_get_tensor_datalayout and
    libxsmm_dnn_duplicate_tensor_datalayout, and are released with
    libxsmm_dnn_destroy_tensor_datalayout (NOTE(review): presumably the
    caller owns returned layouts - confirm against the implementation).
    Field order is part of the public struct layout; do not reorder. */
LIBXSMM_EXTERN_C typedef struct LIBXSMM_RETARGETABLE libxsmm_dnn_tensor_datalayout {
  /** kind of each dimension; presumably num_dims entries - TODO confirm */
  libxsmm_dnn_tensor_dimtype* dim_type;
  /** extent of each dimension; presumably num_dims entries, index-aligned with dim_type - TODO confirm */
  unsigned int* dim_size;
  /** number of dimensions (presumably the length of dim_type and dim_size) */
  unsigned int num_dims;
  libxsmm_dnn_tensor_format format;                /* data format of the tensor buffer */
  libxsmm_dnn_datatype datatype;                   /* element data type */
  libxsmm_dnn_tensor_type tensor_type;             /* role of the tensor (input, filter, bias, ...) */
} libxsmm_dnn_tensor_datalayout;
173 
/* Tensor layout handling. Functions taking a libxsmm_dnn_err_t* status
   presumably report their error code through it. */
/** Returns a copy of layout (NOTE(review): presumably heap-allocated and to be
    released with libxsmm_dnn_destroy_tensor_datalayout - confirm). */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_duplicate_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status);
/** Destroys layout; returns an error code. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor_datalayout(libxsmm_dnn_tensor_datalayout* layout);
/** Compares layout_a with layout_b. NOTE(review): which return value means
    "equal" is not visible in this header - confirm before relying on it. */
LIBXSMM_API unsigned int libxsmm_dnn_compare_tensor_datalayout(const libxsmm_dnn_tensor_datalayout* layout_a, const libxsmm_dnn_tensor_datalayout* layout_b, libxsmm_dnn_err_t* status);
/** Returns the size of a tensor described by layout (presumably in bytes,
    i.e. element count times datatype size - confirm). */
LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_size(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status);
/** Returns the number of elements of a tensor described by layout
    (presumably the product of dim_size[0..num_dims-1] - confirm). */
LIBXSMM_API unsigned int libxsmm_dnn_get_tensor_elements(const libxsmm_dnn_tensor_datalayout* layout, libxsmm_dnn_err_t* status);
180 
/* Create and manage buffers, filters and bias (handle-returning functions
   return non-NULL if successful). */
/** Creates a tensor handle for data described by layout; returns non-NULL on
    success. NOTE(review): "link" suggests data is referenced, not copied - confirm. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_tensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, libxsmm_dnn_err_t* status);
/** Like libxsmm_dnn_link_tensor, for a quantized tensor; exp is presumably the
    scaling exponent later read by libxsmm_dnn_get_qtensor_scf - confirm. */
LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_link_qtensor(const libxsmm_dnn_tensor_datalayout* layout, const void* data, const unsigned char exp, libxsmm_dnn_err_t* status);
/** Replaces the data pointer associated with tensor; returns an error code. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_tensor_data_ptr(libxsmm_dnn_tensor* tensor, const void* data);
/** Returns the data pointer associated with tensor. */
LIBXSMM_API void* libxsmm_dnn_get_tensor_data_ptr(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status);
/** Returns the layout of tensor (NOTE(review): presumably a new copy to be
    released with libxsmm_dnn_destroy_tensor_datalayout - confirm). */
LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_get_tensor_datalayout(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status);
/** Returns the scaling factor (scf) of a quantized tensor. */
LIBXSMM_API unsigned char libxsmm_dnn_get_qtensor_scf(const libxsmm_dnn_tensor* tensor, libxsmm_dnn_err_t* status);
/** Sets the scaling factor (scf) of a quantized tensor; returns an error code. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_set_qtensor_scf(libxsmm_dnn_tensor* tensor, const unsigned char scf);
/** Destroys the tensor handle; returns an error code. NOTE(review): the handle is
    const-qualified although it is destroyed (kept for API compatibility); whether
    the linked data is freed is not visible here - presumably not, confirm. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_tensor(const libxsmm_dnn_tensor* tensor);
/** Presumably fills the tensor's data with zeros; returns an error code.
    NOTE(review): the handle is const although the data it refers to is written. */
LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_zero_tensor(const libxsmm_dnn_tensor* tensor);
191 
192 /**
193  * Copy-in/out from a plain format such [N][C][H][W] or [N][H][W][C]
194  */
195 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyin_tensor(const libxsmm_dnn_tensor* tensor, const void* data, const libxsmm_dnn_tensor_format in_format);
196 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_copyout_tensor(const libxsmm_dnn_tensor* tensor, void* data, const libxsmm_dnn_tensor_format out_format);
197 
198 #endif /*LIBXSMM_DNN_TENSOR_H*/
199 
200