1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_fullyconnected_backward_weight_update.h"
12 #include "libxsmm_dnn_fullyconnected_forward.h"
13 #include "libxsmm_main.h"
14
libxsmm_dnn_create_fullyconnected(libxsmm_dnn_fullyconnected_desc fullyconnected_desc,libxsmm_dnn_err_t * status)15 LIBXSMM_API libxsmm_dnn_fullyconnected* libxsmm_dnn_create_fullyconnected(libxsmm_dnn_fullyconnected_desc fullyconnected_desc, libxsmm_dnn_err_t* status) {
16 libxsmm_dnn_fullyconnected* handle = 0;
17 const libxsmm_trans_descriptor* tr_desc = 0;
18 libxsmm_descriptor_blob blob;
19
20 /* init libxsmm */
21 LIBXSMM_INIT
22
23 if ( ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
24 ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
25 ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ) {
26 handle = (libxsmm_dnn_fullyconnected*)malloc(sizeof(libxsmm_dnn_fullyconnected));
27
28 if (0 != handle) {
29 *status = LIBXSMM_DNN_SUCCESS;
30 /* zero entire content; not only safer but also sets data and code pointers to NULL */
31 memset(handle, 0, sizeof(*handle));
32 /* let's make the description persistent */
33 handle->desc = fullyconnected_desc;
34 /* @TODO perhaps we need a better switch here */
35 if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
36 handle->bk = handle->desc.bk;
37 handle->bn = handle->desc.bn;
38 handle->bc = handle->desc.bc;
39
40 if ( handle->desc.N % handle->bn != 0 ) {
41 handle->bn = handle->desc.N;
42 *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING;
43 }
44 if ( handle->desc.C % handle->bc != 0 ) {
45 handle->bc = handle->desc.C;
46 *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING;
47 }
48 if ( handle->desc.K % handle->bk != 0 ) {
49 handle->bk = handle->desc.K;
50 *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING;
51 }
52 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
53 #if 0
54 handle->fwd_bf = atoi(getenv("FWD_BF"));
55 handle->bwd_bf = atoi(getenv("BWD_BF"));
56 handle->upd_bf = atoi(getenv("UPD_BF"));
57 handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING"));
58 handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING"));
59 handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING"));
60 handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS"));
61 handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS"));
62 handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS"));
63 handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS"));
64 handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS"));
65 handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS"));
66 handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS"));
67 handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS"));
68 #else
69 /* Initialize with default values */
70 handle->fwd_bf = 1;
71 handle->bwd_bf = 1;
72 handle->upd_bf = 1;
73 handle->fwd_2d_blocking = 0;
74 handle->bwd_2d_blocking = 0;
75 handle->upd_2d_blocking = 0;
76 handle->fwd_row_teams = 1;
77 handle->fwd_column_teams = 1;
78 handle->bwd_row_teams = 1;
79 handle->bwd_column_teams = 1;
80 handle->upd_row_teams = 1;
81 handle->upd_column_teams = 1;
82 handle->ifm_subtasks = 1;
83 handle->ofm_subtasks = 1;
84
85 if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) {
86 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
87 handle->fwd_2d_blocking = 1;
88 handle->fwd_row_teams = 14;
89 handle->fwd_column_teams = 2;
90 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
91 handle->bwd_2d_blocking = 0;
92 handle->bwd_row_teams = 1;
93 handle->bwd_column_teams = 1;
94 handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
95 handle->upd_2d_blocking = 0;
96 handle->upd_row_teams = 1;
97 handle->upd_column_teams = 1;
98 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
99 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
100 }
101
102 if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) {
103 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
104 handle->fwd_2d_blocking = 1;
105 handle->fwd_row_teams = 7;
106 handle->fwd_column_teams = 4;
107 handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1;
108 handle->bwd_2d_blocking = 0;
109 handle->bwd_row_teams = 7;
110 handle->bwd_column_teams = 4;
111 handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
112 handle->upd_2d_blocking = 0;
113 handle->upd_row_teams = 7;
114 handle->upd_column_teams = 4;
115 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
116 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
117 }
118
119 if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) {
120 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
121 handle->fwd_2d_blocking = 0;
122 handle->fwd_row_teams = 1;
123 handle->fwd_column_teams = 1;
124 handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
125 handle->bwd_2d_blocking = 0;
126 handle->bwd_row_teams = 1;
127 handle->bwd_column_teams = 1;
128 handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
129 handle->upd_2d_blocking = 0;
130 handle->upd_row_teams = 1;
131 handle->upd_column_teams = 1;
132 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
133 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
134 }
135
136 if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) {
137 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
138 handle->fwd_2d_blocking = 0;
139 handle->fwd_row_teams = 1;
140 handle->fwd_column_teams = 1;
141 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
142 handle->bwd_2d_blocking = 1;
143 handle->bwd_row_teams = 14;
144 handle->bwd_column_teams = 2;
145 handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1;
146 handle->upd_2d_blocking = 0;
147 handle->upd_row_teams = 1;
148 handle->upd_column_teams = 1;
149 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
150 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
151 }
152
153 if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) {
154 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
155 handle->fwd_2d_blocking = 0;
156 handle->fwd_row_teams = 5;
157 handle->fwd_column_teams = 4;
158 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
159 handle->bwd_2d_blocking = 1;
160 handle->bwd_row_teams = 5;
161 handle->bwd_column_teams = 4;
162 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
163 handle->upd_2d_blocking = 0;
164 handle->upd_row_teams = 5;
165 handle->upd_column_teams = 4;
166 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
167 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
168 }
169
170 if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) {
171 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
172 handle->fwd_2d_blocking = 1;
173 handle->fwd_row_teams = 5;
174 handle->fwd_column_teams = 4;
175 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
176 handle->bwd_2d_blocking = 0;
177 handle->bwd_row_teams = 1;
178 handle->bwd_column_teams = 1;
179 handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1;
180 handle->upd_2d_blocking = 0;
181 handle->upd_row_teams = 1;
182 handle->upd_column_teams = 1;
183 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
184 handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
185 }
186
187 if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) {
188 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
189 handle->fwd_2d_blocking = 0;
190 handle->fwd_row_teams = 6;
191 handle->fwd_column_teams = 4;
192 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
193 handle->bwd_2d_blocking = 0;
194 handle->bwd_row_teams = 6;
195 handle->bwd_column_teams = 4;
196 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
197 handle->upd_2d_blocking = 0;
198 handle->upd_row_teams = 6;
199 handle->upd_column_teams = 4;
200 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
201 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
202 }
203 if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) {
204 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
205 handle->fwd_2d_blocking = 0;
206 handle->fwd_row_teams = 5;
207 handle->fwd_column_teams = 4;
208 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
209 handle->bwd_2d_blocking = 1;
210 handle->bwd_row_teams = 12;
211 handle->bwd_column_teams = 2;
212 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
213 handle->upd_2d_blocking = 0;
214 handle->upd_row_teams = 5;
215 handle->upd_column_teams = 4;
216 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
217 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
218 }
219 if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) {
220 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
221 handle->fwd_2d_blocking = 0;
222 handle->fwd_row_teams = 5;
223 handle->fwd_column_teams = 4;
224 handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
225 handle->bwd_2d_blocking = 0;
226 handle->bwd_row_teams = 5;
227 handle->bwd_column_teams = 4;
228 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
229 handle->upd_2d_blocking = 0;
230 handle->upd_row_teams = 5;
231 handle->upd_column_teams = 4;
232 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
233 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
234 }
235 if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) {
236 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
237 handle->fwd_2d_blocking = 1;
238 handle->fwd_row_teams = 5;
239 handle->fwd_column_teams = 4;
240 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
241 handle->bwd_2d_blocking = 0;
242 handle->bwd_row_teams = 1;
243 handle->bwd_column_teams = 1;
244 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
245 handle->upd_2d_blocking = 0;
246 handle->upd_row_teams = 1;
247 handle->upd_column_teams = 1;
248 handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
249 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
250 }
251 if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) {
252 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
253 handle->fwd_2d_blocking = 0;
254 handle->fwd_row_teams = 5;
255 handle->fwd_column_teams = 4;
256 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
257 handle->bwd_2d_blocking = 0;
258 handle->bwd_row_teams = 5;
259 handle->bwd_column_teams = 4;
260 handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
261 handle->upd_2d_blocking = 0;
262 handle->upd_row_teams = 5;
263 handle->upd_column_teams = 4;
264 handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
265 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
266 }
267 if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) {
268 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
269 handle->fwd_2d_blocking = 0;
270 handle->fwd_row_teams = 6;
271 handle->fwd_column_teams = 4;
272 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
273 handle->bwd_2d_blocking = 1;
274 handle->bwd_row_teams = 5;
275 handle->bwd_column_teams = 4;
276 handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
277 handle->upd_2d_blocking = 0;
278 handle->upd_row_teams = 6;
279 handle->upd_column_teams = 4;
280 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
281 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
282 }
283 #endif
284 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
285 #if 0
286 handle->fwd_bf = atoi(getenv("FWD_BF"));
287 handle->bwd_bf = atoi(getenv("BWD_BF"));
288 handle->upd_bf = atoi(getenv("UPD_BF"));
289 handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING"));
290 handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING"));
291 handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING"));
292 handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS"));
293 handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS"));
294 handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS"));
295 handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS"));
296 handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS"));
297 handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS"));
298 handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS"));
299 handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS"));
300 #else
301 /* Initialize with default values */
302 handle->fwd_bf = 1;
303 handle->bwd_bf = 1;
304 handle->upd_bf = 1;
305 handle->fwd_2d_blocking = 0;
306 handle->bwd_2d_blocking = 0;
307 handle->upd_2d_blocking = 0;
308 handle->fwd_row_teams = 1;
309 handle->fwd_column_teams = 1;
310 handle->bwd_row_teams = 1;
311 handle->bwd_column_teams = 1;
312 handle->upd_row_teams = 1;
313 handle->upd_column_teams = 1;
314 handle->ifm_subtasks = 1;
315 handle->ofm_subtasks = 1;
316
317 if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) {
318 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
319 handle->fwd_2d_blocking = 1;
320 handle->fwd_row_teams = 14;
321 handle->fwd_column_teams = 2;
322 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
323 handle->bwd_2d_blocking = 0;
324 handle->bwd_row_teams = 1;
325 handle->bwd_column_teams = 1;
326 handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
327 handle->upd_2d_blocking = 0;
328 handle->upd_row_teams = 1;
329 handle->upd_column_teams = 1;
330 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
331 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
332 }
333
334 if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) {
335 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
336 handle->fwd_2d_blocking = 1;
337 handle->fwd_row_teams = 7;
338 handle->fwd_column_teams = 4;
339 handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1;
340 handle->bwd_2d_blocking = 0;
341 handle->bwd_row_teams = 7;
342 handle->bwd_column_teams = 4;
343 handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
344 handle->upd_2d_blocking = 0;
345 handle->upd_row_teams = 7;
346 handle->upd_column_teams = 4;
347 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
348 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
349 }
350
351 if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) {
352 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
353 handle->fwd_2d_blocking = 0;
354 handle->fwd_row_teams = 1;
355 handle->fwd_column_teams = 1;
356 handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
357 handle->bwd_2d_blocking = 0;
358 handle->bwd_row_teams = 1;
359 handle->bwd_column_teams = 1;
360 handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
361 handle->upd_2d_blocking = 0;
362 handle->upd_row_teams = 1;
363 handle->upd_column_teams = 1;
364 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
365 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
366 }
367
368 if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) {
369 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
370 handle->fwd_2d_blocking = 0;
371 handle->fwd_row_teams = 1;
372 handle->fwd_column_teams = 1;
373 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
374 handle->bwd_2d_blocking = 1;
375 handle->bwd_row_teams = 14;
376 handle->bwd_column_teams = 2;
377 handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1;
378 handle->upd_2d_blocking = 0;
379 handle->upd_row_teams = 1;
380 handle->upd_column_teams = 1;
381 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
382 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
383 }
384
385 if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) {
386 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
387 handle->fwd_2d_blocking = 0;
388 handle->fwd_row_teams = 5;
389 handle->fwd_column_teams = 4;
390 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
391 handle->bwd_2d_blocking = 1;
392 handle->bwd_row_teams = 5;
393 handle->bwd_column_teams = 4;
394 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
395 handle->upd_2d_blocking = 0;
396 handle->upd_row_teams = 5;
397 handle->upd_column_teams = 4;
398 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
399 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
400 }
401
402 if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) {
403 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
404 handle->fwd_2d_blocking = 1;
405 handle->fwd_row_teams = 5;
406 handle->fwd_column_teams = 4;
407 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
408 handle->bwd_2d_blocking = 0;
409 handle->bwd_row_teams = 1;
410 handle->bwd_column_teams = 1;
411 handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1;
412 handle->upd_2d_blocking = 0;
413 handle->upd_row_teams = 1;
414 handle->upd_column_teams = 1;
415 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
416 handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
417 }
418
419 if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) {
420 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
421 handle->fwd_2d_blocking = 0;
422 handle->fwd_row_teams = 6;
423 handle->fwd_column_teams = 4;
424 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
425 handle->bwd_2d_blocking = 0;
426 handle->bwd_row_teams = 6;
427 handle->bwd_column_teams = 4;
428 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
429 handle->upd_2d_blocking = 0;
430 handle->upd_row_teams = 6;
431 handle->upd_column_teams = 4;
432 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
433 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
434 }
435 if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) {
436 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
437 handle->fwd_2d_blocking = 0;
438 handle->fwd_row_teams = 5;
439 handle->fwd_column_teams = 4;
440 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
441 handle->bwd_2d_blocking = 1;
442 handle->bwd_row_teams = 12;
443 handle->bwd_column_teams = 2;
444 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
445 handle->upd_2d_blocking = 0;
446 handle->upd_row_teams = 5;
447 handle->upd_column_teams = 4;
448 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
449 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
450 }
451 if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) {
452 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
453 handle->fwd_2d_blocking = 0;
454 handle->fwd_row_teams = 5;
455 handle->fwd_column_teams = 4;
456 handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
457 handle->bwd_2d_blocking = 0;
458 handle->bwd_row_teams = 5;
459 handle->bwd_column_teams = 4;
460 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
461 handle->upd_2d_blocking = 0;
462 handle->upd_row_teams = 5;
463 handle->upd_column_teams = 4;
464 handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
465 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
466 }
467 if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) {
468 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
469 handle->fwd_2d_blocking = 1;
470 handle->fwd_row_teams = 5;
471 handle->fwd_column_teams = 4;
472 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
473 handle->bwd_2d_blocking = 0;
474 handle->bwd_row_teams = 1;
475 handle->bwd_column_teams = 1;
476 handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
477 handle->upd_2d_blocking = 0;
478 handle->upd_row_teams = 1;
479 handle->upd_column_teams = 1;
480 handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
481 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
482 }
483 if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) {
484 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
485 handle->fwd_2d_blocking = 0;
486 handle->fwd_row_teams = 5;
487 handle->fwd_column_teams = 4;
488 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
489 handle->bwd_2d_blocking = 0;
490 handle->bwd_row_teams = 5;
491 handle->bwd_column_teams = 4;
492 handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
493 handle->upd_2d_blocking = 0;
494 handle->upd_row_teams = 5;
495 handle->upd_column_teams = 4;
496 handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
497 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
498 }
499 if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) {
500 handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
501 handle->fwd_2d_blocking = 0;
502 handle->fwd_row_teams = 6;
503 handle->fwd_column_teams = 4;
504 handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
505 handle->bwd_2d_blocking = 1;
506 handle->bwd_row_teams = 5;
507 handle->bwd_column_teams = 4;
508 handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
509 handle->upd_2d_blocking = 0;
510 handle->upd_row_teams = 6;
511 handle->upd_column_teams = 4;
512 handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
513 handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
514 }
515 #endif
516 }
517 } else {
518 /* check that we cannot fuse */
519 if ( handle->desc.fuse_ops != LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
520 free( handle );
521 *status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
522 return 0;
523 }
524
525 /* we need to compute the memory layout given the */
526 if ( (handle->desc.C % 16 == 0) && (handle->desc.K % 16 == 0) ) {
527 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
528 *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
529 &(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block),
530 LIBXSMM_DNN_DATATYPE_F32, LIBXSMM_DNN_DATATYPE_F32 );
531 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
532 *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
533 &(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block),
534 handle->desc.datatype_in, handle->desc.datatype_out );
535 } else {
536 /* should not happen, not implemented */
537 }
538 } else if ( (handle->desc.C % 64 == 0) && (handle->desc.K == 1000) ) {
539 /* @TODO this a hack for the last FC layer */
540 handle->ifmblock = 64;
541 handle->fm_lp_block = 1;
542 handle->ofmblock = 10;
543 } else if ( (handle->desc.C % 16 == 0) && (handle->desc.K == 1000) ) {
544 /* @TODO this a hack for the last FC layer */
545 handle->ifmblock = 16;
546 handle->fm_lp_block = 1;
547 handle->ofmblock = 10;
548 } else {
549 *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
550 free( handle );
551 return 0;
552 }
553 /* compute the outer blocks */
554 handle->blocksifm = handle->desc.C / handle->ifmblock;
555 handle->blocksofm = handle->desc.K / handle->ofmblock;
556 }
557 /* create barrier */
558 handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
559
560 /* calculate scratch size */
561 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
562 handle->scratch_size = sizeof(float) * ( ( (size_t)handle->desc.C * (size_t)handle->desc.N ) + ( (size_t)handle->desc.C * (size_t)handle->desc.K ) );
563 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
564 /* Let's allocate maximum required scratch */
565 size_t size_fwd = sizeof(float) * handle->desc.K * handle->desc.N;
566 /* In case of K = 1 we pad A and B to "bk=2" */
567 size_t size_bwd = (handle->desc.K != 1) ? ( sizeof(float) * handle->desc.C * handle->desc.N + sizeof(libxsmm_bfloat16) * handle->desc.C * handle->desc.K ) : ( sizeof(float) * handle->desc.C * handle->desc.N + sizeof(libxsmm_bfloat16) * handle->desc.C * 2 + sizeof(libxsmm_bfloat16) * 2 * handle->desc.N );
568 size_t size_upd = sizeof(float) * handle->desc.C * handle->desc.K + sizeof(libxsmm_bfloat16) * handle->desc.threads * handle->bk * handle->bc + sizeof(libxsmm_bfloat16) * (handle->desc.N * (handle->desc.C + handle->desc.K));
569 handle->scratch_size = LIBXSMM_MAX(LIBXSMM_MAX(size_fwd, size_bwd), size_upd);
570 handle->doutput_scratch_mark = handle->scratch_size;
571 handle->scratch_size += 2 * sizeof(libxsmm_bfloat16) * handle->desc.N * handle->desc.K;
572 } else {
573 handle->scratch_size = sizeof(float) * ( (((size_t)handle->desc.C + (size_t)handle->desc.K) * (size_t)handle->desc.N) + ((size_t)handle->desc.C * (size_t)handle->desc.K) );
574 }
575 /* create code pointers in some special cases */
576 if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) && ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) ) {
577 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
578 float alpha = 1.0f;
579 /* beta is set to 1 for ncnc kcck format because ifm is split into 2 blocks */
580 float beta = 1.0f;
581 float zerobeta = 0.0f;
582 int updflags = LIBXSMM_GEMM_FLAGS( 'N', 'T' );
583 /* For UPD kernels we consider subtasking... */
584 libxsmm_blasint M = handle->bk/handle->ofm_subtasks;
585 libxsmm_blasint N = handle->bc/handle->ifm_subtasks;
586
587 libxsmm_blasint lda = (libxsmm_blasint)handle->bk;
588 libxsmm_blasint ldb = (libxsmm_blasint)handle->bc;
589 libxsmm_blasint ldc = (libxsmm_blasint)handle->bk;
590
591 handle->gemm_fwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
592 handle->gemm_fwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
593 handle->gemm_bwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL);
594 handle->gemm_bwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL);
595
596 /* Transpose kernel used for weight transpose in bwd pass */
597 tr_desc = libxsmm_trans_descriptor_init(&blob, sizeof(float), handle->bk, handle->bc, handle->bc);
598 handle->tr_kernel = libxsmm_dispatch_trans(tr_desc);
599
600 /* update has different LDs */
601 lda = (libxsmm_blasint)handle->bk;
602 ldb = (libxsmm_blasint)handle->bc;
603 ldc = (libxsmm_blasint)handle->bk;
604 handle->gemm_upd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, &updflags, NULL);
605 handle->gemm_upd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, &updflags, NULL);
606 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
607 float alpha = 1.0f;
608 float beta = 1.0f;
609 float zerobeta = 0.0f;
610 /* For UPD kernels we consider subtasking... */
611 libxsmm_blasint M = handle->bk/handle->ofm_subtasks;
612 libxsmm_blasint N = handle->bc/handle->ifm_subtasks;
613
614 libxsmm_blasint lda = (libxsmm_blasint)handle->bk;
615 libxsmm_blasint ldb = (libxsmm_blasint)handle->bc;
616 libxsmm_blasint ldc = (libxsmm_blasint)handle->bk;
617
618 handle->gemm_fwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
619 handle->gemm_fwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
620 handle->gemm_fwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
621 /* Special bwd kernels for K == 1 */
622 if (handle->desc.K == 1) {
623 libxsmm_blasint _bk = 2;
624 handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &beta, NULL, NULL);
625 handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &zerobeta, NULL, NULL);
626 } else {
627 handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL);
628 handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL);
629 }
630 lda = (libxsmm_blasint)handle->bk;
631 ldb = (libxsmm_blasint)handle->bn;
632 ldc = (libxsmm_blasint)handle->bk;
633 handle->gemm_upd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
634 handle->gemm_upd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
635 } else {
636
637 }
638 }
639 } else {
640 *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
641 }
642 } else {
643 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
644 }
645
646 return handle;
647 }
648
649
libxsmm_dnn_destroy_fullyconnected(const libxsmm_dnn_fullyconnected * handle)650 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fullyconnected(const libxsmm_dnn_fullyconnected* handle) {
651 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
652
653 if (0 != handle) {
654 /* Deallocate barrier */
655 if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); }
656 /* deallocate handle structure */
657 free(/*remove constness*/(libxsmm_dnn_fullyconnected*)handle);
658 } else {
659 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
660 }
661
662 return status;
663 }
664
665
libxsmm_dnn_fullyconnected_create_tensor_datalayout(const libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)666 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fullyconnected_create_tensor_datalayout(const libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
667 libxsmm_dnn_tensor_datalayout* layout;
668
669 *status = LIBXSMM_DNN_SUCCESS;
670 layout = 0;
671
672 if (handle != 0) {
673 layout = (libxsmm_dnn_tensor_datalayout*) malloc(sizeof(libxsmm_dnn_tensor_datalayout));
674
675 if (layout != 0) {
676 memset(layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
677
678 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ||
679 (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
680 layout->format = handle->desc.buffer_format;
681 if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
682 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
683 layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
684 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
685 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
686
687 if (0 != layout->dim_type && 0 != layout->dim_size) {
688 layout->num_dims = 5;
689 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
690 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
691 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
692 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
693 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
694 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
695 layout->dim_size[0] = handle->ifmblock;
696 layout->dim_size[1] = 1;
697 layout->dim_size[2] = 1;
698 layout->dim_size[3] = handle->blocksifm;
699 layout->dim_size[4] = handle->desc.N;
700 } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
701 layout->dim_size[0] = handle->ofmblock;
702 layout->dim_size[1] = 1;
703 layout->dim_size[2] = 1;
704 layout->dim_size[3] = handle->blocksofm;
705 layout->dim_size[4] = handle->desc.N;
706 } else { /* coverity[dead_error_begin] */
707 free(layout->dim_type);
708 free(layout->dim_size);
709 free(layout);
710 layout = 0; /* make sure a NULL is returned */
711 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
712 }
713 } else {
714 free(layout);
715 layout = 0; /* make sure a NULL is returned */
716 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
717 }
718 } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
719 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
720 layout->datatype = handle->desc.datatype_in;
721 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
722 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
723 if (0 != layout->dim_type && 0 != layout->dim_size) {
724 layout->num_dims = 5;
725 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
726 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
727 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
728 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
729 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
730 layout->dim_size[0] = handle->ifmblock;
731 layout->dim_size[1] = 1;
732 layout->dim_size[2] = 1;
733 layout->dim_size[3] = handle->blocksifm;
734 layout->dim_size[4] = handle->desc.N;
735 } else {
736 free(layout->dim_type);
737 free(layout->dim_size);
738 free(layout);
739 layout = 0; /* make sure a NULL is returned */
740 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
741 }
742 } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
743 layout->datatype = handle->desc.datatype_out;
744 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
745 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
746 if (0 != layout->dim_type && 0 != layout->dim_size) {
747 layout->num_dims = 5;
748 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
749 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
750 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
751 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
752 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
753 layout->dim_size[0] = handle->ofmblock;
754 layout->dim_size[1] = 1;
755 layout->dim_size[2] = 1;
756 layout->dim_size[3] = handle->blocksofm;
757 layout->dim_size[4] = handle->desc.N;
758 } else {
759 free(layout->dim_type);
760 free(layout->dim_size);
761 free(layout);
762 layout = 0; /* make sure a NULL is returned */
763 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
764 }
765 } else {
766 free(layout);
767 layout = 0; /* make sure a NULL is returned */
768 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
769 }
770 } else {
771 free(layout);
772 layout = 0; /* make sure a NULL is returned */
773 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
774 }
775 } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
776 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
777 ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
778 ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
779 layout->datatype = handle->desc.datatype_in;
780 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
781 layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
782 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
783 layout->num_dims = 4;
784 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
785 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
786 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
787 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
788 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
789 layout->dim_size[0] = handle->desc.C;
790 layout->dim_size[1] = 1;
791 layout->dim_size[2] = 1;
792 layout->dim_size[3] = handle->desc.N;
793 } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
794 layout->dim_size[0] = handle->desc.K;
795 layout->dim_size[1] = 1;
796 layout->dim_size[2] = 1;
797 layout->dim_size[3] = handle->desc.N;
798 } else {
799 free(layout->dim_type);
800 free(layout->dim_size);
801 free(layout);
802 layout = 0; /* make sure a NULL is returned */
803 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
804 }
805 }
806 } else {
807 free(layout);
808 layout = 0; /* make sure a NULL is returned */
809 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
810 }
811 } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
812 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
813 ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
814 layout->datatype = handle->desc.datatype_in;
815 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
816 layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
817
818 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
819 layout->num_dims = 4;
820
821 if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) ) {
822 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
823 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
824 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
825 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
826 layout->dim_size[0] = (unsigned int)handle->bc;
827 layout->dim_size[1] = (unsigned int)handle->bn;
828 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
829 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
830 } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) ) {
831 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
832 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
833 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
834 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
835 layout->dim_size[0] = (unsigned int)handle->bk;
836 layout->dim_size[1] = (unsigned int)handle->bn;
837 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
838 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
839 } else {
840 free(layout->dim_type);
841 free(layout->dim_size);
842 free(layout);
843 layout = 0; /* make sure a NULL is returned */
844 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
845 }
846 } else {
847 free(layout);
848 layout = 0; /* make sure a NULL is returned */
849 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
850 }
851 } else {
852 free(layout);
853 layout = 0; /* make sure a NULL is returned */
854 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
855 }
856 } else {
857 free(layout);
858 layout = 0; /* make sure a NULL is returned */
859 *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
860 }
861 } else if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) || (type == LIBXSMM_DNN_FILTER) ) {
862 layout->format = handle->desc.filter_format;
863 layout->tensor_type = LIBXSMM_DNN_FILTER;
864
865 if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
866 if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
867 layout->datatype = handle->desc.datatype_in;
868 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
869 layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
870 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
871 layout->num_dims = 6;
872 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
873 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
874 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
875 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
876 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
877 layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
878 layout->dim_size[0] = handle->ofmblock;
879 layout->dim_size[1] = handle->ifmblock;
880 layout->dim_size[2] = 1;
881 layout->dim_size[3] = 1;
882 layout->dim_size[4] = handle->blocksifm;
883 layout->dim_size[5] = handle->blocksofm;
884 } else {
885 free(layout);
886 layout = 0; /* make sure a NULL is returned */
887 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
888 }
889 } else if ( ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) ||
890 ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
891 layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
892 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
893 layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
894 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
895 layout->num_dims = 7;
896 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
897 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
898 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
899 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
900 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
901 layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
902 layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
903 layout->dim_size[0] = handle->fm_lp_block;
904 layout->dim_size[1] = handle->ofmblock;
905 layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block;
906 layout->dim_size[3] = 1;
907 layout->dim_size[4] = 1;
908 layout->dim_size[5] = handle->blocksifm;
909 layout->dim_size[6] = handle->blocksofm;
910 } else {
911 free(layout);
912 layout = 0; /* make sure a NULL is returned */
913 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
914 }
915 } else {
916 free(layout);
917 layout = 0; /* make sure a NULL is returned */
918 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
919 }
920 } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
921 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
922 ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
923 ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ) {
924 layout->datatype = handle->desc.datatype_in;
925 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
926 layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
927 if (0 != layout->dim_type && 0 != layout->dim_size) {
928 layout->num_dims = 4;
929 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
930 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
931 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
932 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
933 layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
934 layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
935 layout->dim_size[2] = 1;
936 layout->dim_size[3] = 1;
937 } else {
938 free(layout);
939 layout = 0; /* make sure a NULL is returned */
940 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
941 }
942 } else {
943 free(layout);
944 layout = 0; /* make sure a NULL is returned */
945 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
946 }
947 } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
948 if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
949 layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
950 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
951 layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
952
953 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
954 layout->num_dims = 4;
955
956 if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) {
957 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
958 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
959 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
960 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
961 layout->dim_size[0] = (unsigned int)handle->bk;
962 layout->dim_size[1] = (unsigned int)handle->bc;
963 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
964 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
965 } else {
966 free(layout->dim_type);
967 free(layout->dim_size);
968 free(layout);
969 layout = 0; /* make sure a NULL is returned */
970 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
971 }
972 } else {
973 free(layout);
974 layout = 0; /* make sure a NULL is returned */
975 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
976 }
977 } else if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) ) {
978 layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
979 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
980 layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
981
982 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
983 layout->num_dims = 5;
984
985 if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) {
986 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
987 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
988 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
989 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
990 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
991 layout->dim_size[0] = (unsigned int)2;
992 layout->dim_size[1] = (unsigned int)handle->bk;
993 layout->dim_size[2] = (unsigned int)handle->bc/2;
994 layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
995 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
996 } else {
997 free(layout->dim_type);
998 free(layout->dim_size);
999 free(layout);
1000 layout = 0; /* make sure a NULL is returned */
1001 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1002 }
1003 } else {
1004 free(layout);
1005 layout = 0; /* make sure a NULL is returned */
1006 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1007 }
1008 } else {
1009 free(layout);
1010 layout = 0; /* make sure a NULL is returned */
1011 *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
1012 }
1013 } else {
1014 free(layout);
1015 layout = 0; /* make sure a NULL is returned */
1016 *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
1017 }
1018 } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) || (type == LIBXSMM_DNN_CHANNEL_BIAS) ) {
1019 layout->format = handle->desc.buffer_format;
1020 layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
1021
1022 if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
1023 if ( (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) || (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
1024 layout->datatype = handle->desc.datatype_out;
1025 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
1026 layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
1027
1028 if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
1029 layout->num_dims = 2;
1030 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
1031 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
1032 layout->dim_size[0] = (unsigned int)handle->bk;
1033 layout->dim_size[1] = (unsigned int)(handle->desc.K / handle->bk);
1034 } else {
1035 free(layout->dim_type);
1036 free(layout->dim_size);
1037 free(layout);
1038 layout = 0; /* make sure a NULL is returned */
1039 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
1040 }
1041 }
1042 } else {
1043 free(layout);
1044 layout = 0; /* make sure a NULL is returned */
1045 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1046 }
1047 } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) {
1048 layout->format = handle->desc.buffer_format;
1049 layout->tensor_type = LIBXSMM_DNN_RELU_MASK;
1050
1051 if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
1052 layout->datatype = LIBXSMM_DNN_DATATYPE_I8;
1053 layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
1054 layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
1055
1056 if (0 != layout->dim_type && 0 != layout->dim_size) {
1057 layout->num_dims = 1;
1058 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
1059 layout->dim_size[0] = handle->desc.N * handle->desc.K;
1060 } else {
1061 free(layout->dim_type);
1062 free(layout->dim_size);
1063 free(layout);
1064 layout = 0; /* make sure a NULL is returned */
1065 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
1066 }
1067 } else {
1068 free(layout);
1069 layout = 0; /* make sure a NULL is returned */
1070 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1071 }
1072 } else {
1073 free(layout);
1074 layout = 0; /* make sure a NULL is returned */
1075 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1076 }
1077 } else {
1078 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
1079 }
1080 }
1081 else {
1082 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1083 }
1084
1085 return layout;
1086 }
1087
libxsmm_dnn_fullyconnected_get_scratch_size(const libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_err_t * status)1088 LIBXSMM_API size_t libxsmm_dnn_fullyconnected_get_scratch_size(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status) {
1089 size_t l_scratch_size = 0;
1090 *status = LIBXSMM_DNN_SUCCESS;
1091
1092 if (0 != handle) {
1093 l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */
1094 } else {
1095 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1096 }
1097
1098 return l_scratch_size;
1099 }
1100
1101
libxsmm_dnn_fullyconnected_get_scratch_ptr(const libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_err_t * status)1102 LIBXSMM_API void* libxsmm_dnn_fullyconnected_get_scratch_ptr(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status)
1103 {
1104 *status = LIBXSMM_DNN_SUCCESS;
1105
1106 if (0 != handle) {
1107 return handle->scratch;
1108 } else {
1109 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1110 }
1111
1112 return 0;
1113 }
1114
1115
libxsmm_dnn_fullyconnected_bind_scratch(libxsmm_dnn_fullyconnected * handle,const void * scratch)1116 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_scratch(libxsmm_dnn_fullyconnected* handle, const void* scratch) {
1117 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1118 uintptr_t address = (uintptr_t)scratch;
1119 size_t offset = 0;
1120
1121 if (scratch == 0) {
1122 status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1123 return status;
1124 }
1125
1126 if (0 != handle) {
1127 /* align the internal scratch buffer if needed */
1128 if (address % 64 == 0) {
1129 handle->scratch = (void*)address;
1130 } else {
1131 offset = (64 - address % 64);
1132 handle->scratch = (void*)(address+offset);
1133 }
1134 } else {
1135 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1136 }
1137
1138 return status;
1139 }
1140
1141
libxsmm_dnn_fullyconnected_release_scratch(libxsmm_dnn_fullyconnected * handle)1142 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_scratch(libxsmm_dnn_fullyconnected* handle) {
1143 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1144
1145 if (0 != handle) {
1146 handle->scratch = 0;
1147 } else {
1148 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1149 }
1150
1151 return status;
1152 }
1153
1154
libxsmm_dnn_fullyconnected_bind_tensor(libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor * tensor,const libxsmm_dnn_tensor_type type)1155 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
1156 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1157
1158 /* check for tensor type */
1159 if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
1160 (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
1161 (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
1162 (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
1163 (type != LIBXSMM_DNN_RELU_MASK) ) {
1164 status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1165 return status;
1166 }
1167
1168 if (handle != 0 && tensor != 0) {
1169 libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout(handle, type, &status);
1170
1171 if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
1172 if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
1173 handle->reg_input = (libxsmm_dnn_tensor*)tensor;
1174 } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
1175 handle->grad_input = (libxsmm_dnn_tensor*)tensor;
1176 } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
1177 handle->reg_output = (libxsmm_dnn_tensor*)tensor;
1178 } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
1179 handle->grad_output = (libxsmm_dnn_tensor*)tensor;
1180 } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
1181 handle->reg_filter = (libxsmm_dnn_tensor*)tensor;
1182 } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
1183 handle->grad_filter = (libxsmm_dnn_tensor*)tensor;
1184 } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
1185 handle->reg_bias = (libxsmm_dnn_tensor*)tensor;
1186 } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
1187 handle->grad_bias = (libxsmm_dnn_tensor*)tensor;
1188 } else if ( type == LIBXSMM_DNN_RELU_MASK ) {
1189 handle->relumask = (libxsmm_dnn_tensor*)tensor;
1190 } else {
1191 /* cannot happen */
1192 }
1193 } else {
1194 status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
1195 }
1196
1197 libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
1198 }
1199 else {
1200 status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
1201 }
1202
1203 return status;
1204 }
1205
1206
libxsmm_dnn_fullyconnected_get_tensor(libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)1207 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fullyconnected_get_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
1208 libxsmm_dnn_tensor* return_tensor = 0;
1209
1210 *status = LIBXSMM_DNN_SUCCESS;
1211
1212 /* check for tensor type */
1213 if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
1214 (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
1215 (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
1216 (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
1217 (type != LIBXSMM_DNN_RELU_MASK) ) {
1218 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1219 return return_tensor;
1220 }
1221
1222 if (handle != 0) {
1223 if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
1224 return_tensor = handle->reg_input;
1225 } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
1226 return_tensor = handle->grad_input;
1227 } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
1228 return_tensor = handle->reg_output;
1229 } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
1230 return_tensor = handle->grad_output;
1231 } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
1232 return_tensor = handle->reg_filter;
1233 } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
1234 return_tensor = handle->grad_filter;
1235 } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
1236 return_tensor = handle->reg_bias;
1237 } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
1238 return_tensor = handle->grad_bias;
1239 } else if ( type == LIBXSMM_DNN_RELU_MASK ) {
1240 return_tensor = handle->relumask;
1241 } else {
1242 /* cannot happen */
1243 }
1244 } else {
1245 *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1246 }
1247
1248 return return_tensor;
1249 }
1250
1251
libxsmm_dnn_fullyconnected_release_tensor(libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor_type type)1252 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type) {
1253 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1254
1255 /* check for tensor type */
1256 if ( (type != LIBXSMM_DNN_REGULAR_INPUT) && (type != LIBXSMM_DNN_GRADIENT_INPUT) &&
1257 (type != LIBXSMM_DNN_REGULAR_OUTPUT) && (type != LIBXSMM_DNN_GRADIENT_OUTPUT) &&
1258 (type != LIBXSMM_DNN_REGULAR_FILTER) && (type != LIBXSMM_DNN_GRADIENT_FILTER) &&
1259 (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
1260 (type != LIBXSMM_DNN_RELU_MASK) ) {
1261 status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1262 return status;
1263 }
1264
1265 if (handle != 0) {
1266 if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
1267 handle->reg_input = 0;
1268 } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
1269 handle->grad_input = 0;
1270 } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
1271 handle->reg_output = 0;
1272 } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
1273 handle->grad_output = 0;
1274 } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
1275 handle->reg_filter = 0;
1276 } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
1277 handle->grad_filter = 0;
1278 } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
1279 handle->reg_bias = 0;
1280 } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
1281 handle->grad_bias = 0;
1282 } else if ( type == LIBXSMM_DNN_RELU_MASK ) {
1283 handle->relumask = 0;
1284 } else {
1285 /* cannot happen */
1286 }
1287 } else {
1288 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1289 }
1290
1291 return status;
1292 }
1293
1294
libxsmm_dnn_fullyconnected_execute_st(libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)1295 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_execute_st(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind,
1296 /*unsigned*/int start_thread, /*unsigned*/int tid) {
1297 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1298 LIBXSMM_UNUSED( start_thread );
1299 LIBXSMM_UNUSED( tid );
1300
1301 if (0 != handle) {
1302 switch (kind) {
1303 case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1304 if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) ) {
1305 status = libxsmm_dnn_fullyconnected_st_fwd_custom( handle, start_thread, tid );
1306 } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
1307 status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck( handle, start_thread, tid );
1308 } else {
1309 status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC;
1310 }
1311 } break;
1312 case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1313 case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1314 case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
1315 if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) ) {
1316 status = libxsmm_dnn_fullyconnected_st_bwdupd_custom( handle, kind, start_thread, tid );
1317 } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
1318 status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck( handle, kind, start_thread, tid );
1319 } else {
1320 status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC;
1321 }
1322 } break;
1323 default: {
1324 status = LIBXSMM_DNN_ERR_INVALID_KIND;
1325 }
1326 }
1327 }
1328 else {
1329 status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1330 }
1331
1332 return status;
1333 }
1334
1335