1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Sasikanth Avancha (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_fullyconnected_backward_weight_update.h"
12 #include "libxsmm_dnn_fullyconnected_forward.h"
13 #include "libxsmm_main.h"
14 
15 LIBXSMM_API libxsmm_dnn_fullyconnected* libxsmm_dnn_create_fullyconnected(libxsmm_dnn_fullyconnected_desc fullyconnected_desc, libxsmm_dnn_err_t* status) {
16   libxsmm_dnn_fullyconnected* handle = 0;
17   const libxsmm_trans_descriptor* tr_desc = 0;
18   libxsmm_descriptor_blob blob;
19 
20   /* init libxsmm */
21   LIBXSMM_INIT
22 
23   if ( ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)) ||
24        ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32)  && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32))  ||
25        ((fullyconnected_desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (fullyconnected_desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32))     ) {
26     handle = (libxsmm_dnn_fullyconnected*)malloc(sizeof(libxsmm_dnn_fullyconnected));
27 
28     if (0 != handle) {
29       *status = LIBXSMM_DNN_SUCCESS;
30       /* zero entire content; not only safer but also sets data and code pointers to NULL */
31       memset(handle, 0, sizeof(*handle));
32       /* let's make the description persistent */
33       handle->desc = fullyconnected_desc;
34       /* @TODO perhaps we need a better switch here */
35       if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
36         handle->bk = handle->desc.bk;
37         handle->bn = handle->desc.bn;
38         handle->bc = handle->desc.bc;
39 
40         if ( handle->desc.N % handle->bn != 0 ) {
41           handle->bn = handle->desc.N;
42           *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_N_BLOCKING;
43         }
44         if ( handle->desc.C % handle->bc != 0 ) {
45           handle->bc = handle->desc.C;
46           *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_C_BLOCKING;
47         }
48         if ( handle->desc.K % handle->bk != 0 ) {
49           handle->bk = handle->desc.K;
50           *status = LIBXSMM_DNN_WARN_FC_SUBOPTIMAL_K_BLOCKING;
51         }
52         if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) )  {
53 #if 0
54           handle->fwd_bf = atoi(getenv("FWD_BF"));
55           handle->bwd_bf = atoi(getenv("BWD_BF"));
56           handle->upd_bf = atoi(getenv("UPD_BF"));
57           handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING"));
58           handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING"));
59           handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING"));
60           handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS"));
61           handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS"));
62           handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS"));
63           handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS"));
64           handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS"));
65           handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS"));
66           handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS"));
67           handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS"));
68 #else
69           /* Initialize with default values */
70           handle->fwd_bf = 1;
71           handle->bwd_bf = 1;
72           handle->upd_bf = 1;
73           handle->fwd_2d_blocking = 0;
74           handle->bwd_2d_blocking = 0;
75           handle->upd_2d_blocking = 0;
76           handle->fwd_row_teams = 1;
77           handle->fwd_column_teams = 1;
78           handle->bwd_row_teams = 1;
79           handle->bwd_column_teams = 1;
80           handle->upd_row_teams = 1;
81           handle->upd_column_teams = 1;
82           handle->ifm_subtasks = 1;
83           handle->ofm_subtasks = 1;
84 
85           if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) {
86             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
87             handle->fwd_2d_blocking = 1;
88             handle->fwd_row_teams = 14;
89             handle->fwd_column_teams = 2;
90             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
91             handle->bwd_2d_blocking = 0;
92             handle->bwd_row_teams = 1;
93             handle->bwd_column_teams = 1;
94             handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
95             handle->upd_2d_blocking = 0;
96             handle->upd_row_teams = 1;
97             handle->upd_column_teams = 1;
98             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
99             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
100           }
101 
102           if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) {
103             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
104             handle->fwd_2d_blocking = 1;
105             handle->fwd_row_teams = 7;
106             handle->fwd_column_teams = 4;
107             handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1;
108             handle->bwd_2d_blocking = 0;
109             handle->bwd_row_teams = 7;
110             handle->bwd_column_teams = 4;
111             handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
112             handle->upd_2d_blocking = 0;
113             handle->upd_row_teams = 7;
114             handle->upd_column_teams = 4;
115             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
116             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
117           }
118 
119           if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) {
120             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
121             handle->fwd_2d_blocking = 0;
122             handle->fwd_row_teams = 1;
123             handle->fwd_column_teams = 1;
124             handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
125             handle->bwd_2d_blocking = 0;
126             handle->bwd_row_teams = 1;
127             handle->bwd_column_teams = 1;
128             handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
129             handle->upd_2d_blocking = 0;
130             handle->upd_row_teams = 1;
131             handle->upd_column_teams = 1;
132             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
133             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
134           }
135 
136           if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) {
137             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
138             handle->fwd_2d_blocking = 0;
139             handle->fwd_row_teams = 1;
140             handle->fwd_column_teams = 1;
141             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
142             handle->bwd_2d_blocking = 1;
143             handle->bwd_row_teams = 14;
144             handle->bwd_column_teams = 2;
145             handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1;
146             handle->upd_2d_blocking = 0;
147             handle->upd_row_teams = 1;
148             handle->upd_column_teams = 1;
149             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
150             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
151           }
152 
153           if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) {
154             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
155             handle->fwd_2d_blocking = 0;
156             handle->fwd_row_teams = 5;
157             handle->fwd_column_teams = 4;
158             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
159             handle->bwd_2d_blocking = 1;
160             handle->bwd_row_teams = 5;
161             handle->bwd_column_teams = 4;
162             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
163             handle->upd_2d_blocking = 0;
164             handle->upd_row_teams = 5;
165             handle->upd_column_teams = 4;
166             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
167             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
168           }
169 
170           if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) {
171             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
172             handle->fwd_2d_blocking = 1;
173             handle->fwd_row_teams = 5;
174             handle->fwd_column_teams = 4;
175             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
176             handle->bwd_2d_blocking = 0;
177             handle->bwd_row_teams = 1;
178             handle->bwd_column_teams = 1;
179             handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1;
180             handle->upd_2d_blocking = 0;
181             handle->upd_row_teams = 1;
182             handle->upd_column_teams = 1;
183             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
184             handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
185           }
186 
187           if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) {
188             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
189             handle->fwd_2d_blocking = 0;
190             handle->fwd_row_teams = 6;
191             handle->fwd_column_teams = 4;
192             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
193             handle->bwd_2d_blocking = 0;
194             handle->bwd_row_teams = 6;
195             handle->bwd_column_teams = 4;
196             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
197             handle->upd_2d_blocking = 0;
198             handle->upd_row_teams = 6;
199             handle->upd_column_teams = 4;
200             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
201             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
202           }
203           if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) {
204             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
205             handle->fwd_2d_blocking = 0;
206             handle->fwd_row_teams = 5;
207             handle->fwd_column_teams = 4;
208             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
209             handle->bwd_2d_blocking = 1;
210             handle->bwd_row_teams = 12;
211             handle->bwd_column_teams = 2;
212             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
213             handle->upd_2d_blocking = 0;
214             handle->upd_row_teams = 5;
215             handle->upd_column_teams = 4;
216             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
217             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
218           }
219           if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) {
220             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
221             handle->fwd_2d_blocking = 0;
222             handle->fwd_row_teams = 5;
223             handle->fwd_column_teams = 4;
224             handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
225             handle->bwd_2d_blocking = 0;
226             handle->bwd_row_teams = 5;
227             handle->bwd_column_teams = 4;
228             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
229             handle->upd_2d_blocking = 0;
230             handle->upd_row_teams = 5;
231             handle->upd_column_teams = 4;
232             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
233             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
234           }
235           if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) {
236             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
237             handle->fwd_2d_blocking = 1;
238             handle->fwd_row_teams = 5;
239             handle->fwd_column_teams = 4;
240             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
241             handle->bwd_2d_blocking = 0;
242             handle->bwd_row_teams = 1;
243             handle->bwd_column_teams = 1;
244             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
245             handle->upd_2d_blocking = 0;
246             handle->upd_row_teams = 1;
247             handle->upd_column_teams = 1;
248             handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
249             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
250           }
251           if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) {
252             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
253             handle->fwd_2d_blocking = 0;
254             handle->fwd_row_teams = 5;
255             handle->fwd_column_teams = 4;
256             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
257             handle->bwd_2d_blocking = 0;
258             handle->bwd_row_teams = 5;
259             handle->bwd_column_teams = 4;
260             handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
261             handle->upd_2d_blocking = 0;
262             handle->upd_row_teams = 5;
263             handle->upd_column_teams = 4;
264             handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
265             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
266           }
267           if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) {
268             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
269             handle->fwd_2d_blocking = 0;
270             handle->fwd_row_teams = 6;
271             handle->fwd_column_teams = 4;
272             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
273             handle->bwd_2d_blocking = 1;
274             handle->bwd_row_teams = 5;
275             handle->bwd_column_teams = 4;
276             handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
277             handle->upd_2d_blocking = 0;
278             handle->upd_row_teams = 6;
279             handle->upd_column_teams = 4;
280             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
281             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
282           }
283 #endif
284         } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) )  {
285 #if 0
286           handle->fwd_bf = atoi(getenv("FWD_BF"));
287           handle->bwd_bf = atoi(getenv("BWD_BF"));
288           handle->upd_bf = atoi(getenv("UPD_BF"));
289           handle->fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING"));
290           handle->bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING"));
291           handle->upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING"));
292           handle->fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS"));
293           handle->fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS"));
294           handle->bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS"));
295           handle->bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS"));
296           handle->upd_row_teams = atoi(getenv("UPD_ROW_TEAMS"));
297           handle->upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS"));
298           handle->ifm_subtasks = atoi(getenv("IFM_SUBTASKS"));
299           handle->ofm_subtasks = atoi(getenv("OFM_SUBTASKS"));
300 #else
301           /* Initialize with default values */
302           handle->fwd_bf = 1;
303           handle->bwd_bf = 1;
304           handle->upd_bf = 1;
305           handle->fwd_2d_blocking = 0;
306           handle->bwd_2d_blocking = 0;
307           handle->upd_2d_blocking = 0;
308           handle->fwd_row_teams = 1;
309           handle->fwd_column_teams = 1;
310           handle->bwd_row_teams = 1;
311           handle->bwd_column_teams = 1;
312           handle->upd_row_teams = 1;
313           handle->upd_column_teams = 1;
314           handle->ifm_subtasks = 1;
315           handle->ofm_subtasks = 1;
316 
317           if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 28) {
318             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
319             handle->fwd_2d_blocking = 1;
320             handle->fwd_row_teams = 14;
321             handle->fwd_column_teams = 2;
322             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
323             handle->bwd_2d_blocking = 0;
324             handle->bwd_row_teams = 1;
325             handle->bwd_column_teams = 1;
326             handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
327             handle->upd_2d_blocking = 0;
328             handle->upd_row_teams = 1;
329             handle->upd_column_teams = 1;
330             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
331             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
332           }
333 
334           if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 28) {
335             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
336             handle->fwd_2d_blocking = 1;
337             handle->fwd_row_teams = 7;
338             handle->fwd_column_teams = 4;
339             handle->bwd_bf = ((handle->desc.K/handle->bk) % 8 == 0) ? 8 : 1;
340             handle->bwd_2d_blocking = 0;
341             handle->bwd_row_teams = 7;
342             handle->bwd_column_teams = 4;
343             handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
344             handle->upd_2d_blocking = 0;
345             handle->upd_row_teams = 7;
346             handle->upd_column_teams = 4;
347             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
348             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
349           }
350 
351           if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 28) {
352             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
353             handle->fwd_2d_blocking = 0;
354             handle->fwd_row_teams = 1;
355             handle->fwd_column_teams = 1;
356             handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
357             handle->bwd_2d_blocking = 0;
358             handle->bwd_row_teams = 1;
359             handle->bwd_column_teams = 1;
360             handle->upd_bf = ((handle->desc.N/handle->bn) % 14 == 0) ? 14 : 1;
361             handle->upd_2d_blocking = 0;
362             handle->upd_row_teams = 1;
363             handle->upd_column_teams = 1;
364             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
365             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
366           }
367 
368           if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 28) {
369             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
370             handle->fwd_2d_blocking = 0;
371             handle->fwd_row_teams = 1;
372             handle->fwd_column_teams = 1;
373             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
374             handle->bwd_2d_blocking = 1;
375             handle->bwd_row_teams = 14;
376             handle->bwd_column_teams = 2;
377             handle->upd_bf = ((handle->desc.N/handle->bn) % 2 == 0) ? 2 : 1;
378             handle->upd_2d_blocking = 0;
379             handle->upd_row_teams = 1;
380             handle->upd_column_teams = 1;
381             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
382             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
383           }
384 
385           if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 20) {
386             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
387             handle->fwd_2d_blocking = 0;
388             handle->fwd_row_teams = 5;
389             handle->fwd_column_teams = 4;
390             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
391             handle->bwd_2d_blocking = 1;
392             handle->bwd_row_teams = 5;
393             handle->bwd_column_teams = 4;
394             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
395             handle->upd_2d_blocking = 0;
396             handle->upd_row_teams = 5;
397             handle->upd_column_teams = 4;
398             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
399             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
400           }
401 
402           if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 20) {
403             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
404             handle->fwd_2d_blocking = 1;
405             handle->fwd_row_teams = 5;
406             handle->fwd_column_teams = 4;
407             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
408             handle->bwd_2d_blocking = 0;
409             handle->bwd_row_teams = 1;
410             handle->bwd_column_teams = 1;
411             handle->upd_bf = ((handle->desc.N/handle->bn) % 9 == 0) ? 9 : 1;
412             handle->upd_2d_blocking = 0;
413             handle->upd_row_teams = 1;
414             handle->upd_column_teams = 1;
415             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
416             handle->ofm_subtasks = ((handle->bk % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
417           }
418 
419           if (handle->desc.C == 1024 && handle->desc.K == 1024 && handle->desc.threads == 24) {
420             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
421             handle->fwd_2d_blocking = 0;
422             handle->fwd_row_teams = 6;
423             handle->fwd_column_teams = 4;
424             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
425             handle->bwd_2d_blocking = 0;
426             handle->bwd_row_teams = 6;
427             handle->bwd_column_teams = 4;
428             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
429             handle->upd_2d_blocking = 0;
430             handle->upd_row_teams = 6;
431             handle->upd_column_teams = 4;
432             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
433             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
434           }
435           if (handle->desc.C == 100 && handle->desc.K == 1024 && handle->desc.threads == 24) {
436             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
437             handle->fwd_2d_blocking = 0;
438             handle->fwd_row_teams = 5;
439             handle->fwd_column_teams = 4;
440             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
441             handle->bwd_2d_blocking = 1;
442             handle->bwd_row_teams = 12;
443             handle->bwd_column_teams = 2;
444             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
445             handle->upd_2d_blocking = 0;
446             handle->upd_row_teams = 5;
447             handle->upd_column_teams = 4;
448             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
449             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
450           }
451           if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 24) {
452             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
453             handle->fwd_2d_blocking = 0;
454             handle->fwd_row_teams = 5;
455             handle->fwd_column_teams = 4;
456             handle->bwd_bf = ((handle->desc.K/handle->bk) % 4 == 0) ? 4 : 1;
457             handle->bwd_2d_blocking = 0;
458             handle->bwd_row_teams = 5;
459             handle->bwd_column_teams = 4;
460             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
461             handle->upd_2d_blocking = 0;
462             handle->upd_row_teams = 5;
463             handle->upd_column_teams = 4;
464             handle->ifm_subtasks = ((handle->bc % 2 == 0) && (handle->upd_2d_blocking == 0)) ? 2 : 1;
465             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
466           }
467           if (handle->desc.C == 512 && handle->desc.K == 512 && handle->desc.threads == 20) {
468             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
469             handle->fwd_2d_blocking = 1;
470             handle->fwd_row_teams = 5;
471             handle->fwd_column_teams = 4;
472             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
473             handle->bwd_2d_blocking = 0;
474             handle->bwd_row_teams = 1;
475             handle->bwd_column_teams = 1;
476             handle->upd_bf = ((handle->desc.N/handle->bn) % 15 == 0) ? 15 : 1;
477             handle->upd_2d_blocking = 0;
478             handle->upd_row_teams = 1;
479             handle->upd_column_teams = 1;
480             handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
481             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
482           }
483           if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 24) {
484             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
485             handle->fwd_2d_blocking = 0;
486             handle->fwd_row_teams = 5;
487             handle->fwd_column_teams = 4;
488             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
489             handle->bwd_2d_blocking = 0;
490             handle->bwd_row_teams = 5;
491             handle->bwd_column_teams = 4;
492             handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
493             handle->upd_2d_blocking = 0;
494             handle->upd_row_teams = 5;
495             handle->upd_column_teams = 4;
496             handle->ifm_subtasks = ((handle->bc % 4 == 0) && (handle->upd_2d_blocking == 0)) ? 4 : 1;
497             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
498           }
499           if (handle->desc.C == 1024 && handle->desc.K == 1 && handle->desc.threads == 20) {
500             handle->fwd_bf = 1/*((handle->desc.C/handle->bc) % 1 == 0) ? 1 : 1*/;
501             handle->fwd_2d_blocking = 0;
502             handle->fwd_row_teams = 6;
503             handle->fwd_column_teams = 4;
504             handle->bwd_bf = 1/*((handle->desc.K/handle->bk) % 1 == 0) ? 1 : 1*/;
505             handle->bwd_2d_blocking = 1;
506             handle->bwd_row_teams = 5;
507             handle->bwd_column_teams = 4;
508             handle->upd_bf = 1/*((handle->desc.N/handle->bn) % 1 == 0) ? 1 : 1*/;
509             handle->upd_2d_blocking = 0;
510             handle->upd_row_teams = 6;
511             handle->upd_column_teams = 4;
512             handle->ifm_subtasks = 1/*((handle->bc % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
513             handle->ofm_subtasks = 1/*((handle->bk % 1 == 0) && (handle->upd_2d_blocking == 0)) ? 1 : 1*/;
514           }
515 #endif
516         }
517       } else {
518         /* check that we cannot fuse */
519         if ( handle->desc.fuse_ops != LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE  ) {
520           free( handle );
521           *status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
522           return 0;
523         }
524 
525         /* we need to compute the memory layout given the */
526         if ( (handle->desc.C % 16 == 0) && (handle->desc.K % 16 == 0) ) {
527           if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
528             *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
529                 &(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block),
530                 LIBXSMM_DNN_DATATYPE_F32, LIBXSMM_DNN_DATATYPE_F32 );
531           } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
532             *status = libxsmm_dnn_get_feature_map_blocks( handle->desc.C, handle->desc.K,
533                 &(handle->ifmblock), &(handle->ofmblock), &(handle->fm_lp_block),
534                 handle->desc.datatype_in, handle->desc.datatype_out );
535           } else {
536             /* should not happen, not implemented */
537           }
538         } else if ( (handle->desc.C % 64 == 0) && (handle->desc.K == 1000) ) {
539           /* @TODO this a hack for the last FC layer */
540           handle->ifmblock = 64;
541           handle->fm_lp_block = 1;
542           handle->ofmblock = 10;
543         } else if ( (handle->desc.C % 16 == 0) && (handle->desc.K == 1000) ) {
544           /* @TODO this a hack for the last FC layer */
545           handle->ifmblock = 16;
546           handle->fm_lp_block = 1;
547           handle->ofmblock = 10;
548         } else {
549           *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
550           free( handle );
551           return 0;
552         }
553         /* compute the outer blocks */
554         handle->blocksifm = handle->desc.C / handle->ifmblock;
555         handle->blocksofm = handle->desc.K / handle->ofmblock;
556       }
557       /* create barrier */
558       handle->barrier = libxsmm_barrier_create(handle->desc.threads, 1);
559 
560       /* calculate scratch size */
561       if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
562         handle->scratch_size = sizeof(float) * ( ( (size_t)handle->desc.C * (size_t)handle->desc.N ) + ( (size_t)handle->desc.C * (size_t)handle->desc.K ) );
563       } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16)  ) {
564         /* Let's allocate maximum required scratch  */
565         size_t size_fwd = sizeof(float) * handle->desc.K * handle->desc.N;
566         /* In case of K = 1 we pad A and B to "bk=2" */
567         size_t size_bwd = (handle->desc.K != 1) ? ( sizeof(float) * handle->desc.C * handle->desc.N + sizeof(libxsmm_bfloat16) * handle->desc.C * handle->desc.K ) : ( sizeof(float) * handle->desc.C * handle->desc.N + sizeof(libxsmm_bfloat16) * handle->desc.C * 2 + sizeof(libxsmm_bfloat16) * 2 * handle->desc.N );
568         size_t size_upd = sizeof(float) * handle->desc.C * handle->desc.K + sizeof(libxsmm_bfloat16) * handle->desc.threads * handle->bk * handle->bc + sizeof(libxsmm_bfloat16) * (handle->desc.N * (handle->desc.C + handle->desc.K));
569         handle->scratch_size = LIBXSMM_MAX(LIBXSMM_MAX(size_fwd, size_bwd), size_upd);
570         handle->doutput_scratch_mark = handle->scratch_size;
571         handle->scratch_size += 2 * sizeof(libxsmm_bfloat16) * handle->desc.N *  handle->desc.K;
572       } else {
573         handle->scratch_size = sizeof(float) * ( (((size_t)handle->desc.C + (size_t)handle->desc.K) * (size_t)handle->desc.N) + ((size_t)handle->desc.C * (size_t)handle->desc.K) );
574       }
575       /* create code pointers in some special cases */
576       if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) && ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0)  ) {
577         if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
578           float alpha = 1.0f;
579           /* beta is set to 1 for ncnc kcck format because ifm is split into 2 blocks */
580           float beta  = 1.0f;
581           float zerobeta  = 0.0f;
582           int updflags = LIBXSMM_GEMM_FLAGS( 'N', 'T' );
583           /* For UPD kernels we consider subtasking... */
584           libxsmm_blasint M = handle->bk/handle->ofm_subtasks;
585           libxsmm_blasint N = handle->bc/handle->ifm_subtasks;
586 
587           libxsmm_blasint lda = (libxsmm_blasint)handle->bk;
588           libxsmm_blasint ldb = (libxsmm_blasint)handle->bc;
589           libxsmm_blasint ldc = (libxsmm_blasint)handle->bk;
590 
591           handle->gemm_fwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
592           handle->gemm_fwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(float), handle->bc*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
593           handle->gemm_bwd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL);
594           handle->gemm_bwd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(float), handle->bk*handle->bn*sizeof(float), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL);
595 
596           /* Transpose kernel used for weight transpose in bwd pass */
597           tr_desc = libxsmm_trans_descriptor_init(&blob, sizeof(float), handle->bk, handle->bc, handle->bc);
598           handle->tr_kernel = libxsmm_dispatch_trans(tr_desc);
599 
600           /* update has different LDs */
601           lda = (libxsmm_blasint)handle->bk;
602           ldb = (libxsmm_blasint)handle->bc;
603           ldc = (libxsmm_blasint)handle->bk;
604           handle->gemm_upd.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &beta, &updflags, NULL);
605           handle->gemm_upd2.xgemm.smrs = libxsmm_smmdispatch_reducebatch_strd(M, N, handle->bn, handle->desc.K*handle->bn*sizeof(float), handle->desc.C*handle->bn*sizeof(float), &lda, &ldb, &ldc, &alpha, &zerobeta, &updflags, NULL);
606         } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
607           float alpha = 1.0f;
608           float beta  = 1.0f;
609           float zerobeta  = 0.0f;
610           /* For UPD kernels we consider subtasking... */
611           libxsmm_blasint M = handle->bk/handle->ofm_subtasks;
612           libxsmm_blasint N = handle->bc/handle->ifm_subtasks;
613 
614           libxsmm_blasint lda = (libxsmm_blasint)handle->bk;
615           libxsmm_blasint ldb = (libxsmm_blasint)handle->bc;
616           libxsmm_blasint ldc = (libxsmm_blasint)handle->bk;
617 
618           handle->gemm_fwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
619           handle->gemm_fwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
620           handle->gemm_fwd3.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bk, handle->bn, handle->bc, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
621           /* Special bwd kernels for K == 1 */
622           if (handle->desc.K == 1) {
623             libxsmm_blasint _bk = 2;
624             handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &beta, NULL, NULL);
625             handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, _bk, _bk*handle->bc*sizeof(libxsmm_bfloat16), _bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &_bk, &ldb, &alpha, &zerobeta, NULL, NULL);
626           } else {
627             handle->gemm_bwd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &beta, NULL, NULL);
628             handle->gemm_bwd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(handle->bc, handle->bn, handle->bk, handle->bk*handle->bc*sizeof(libxsmm_bfloat16), handle->bk*handle->bn*sizeof(libxsmm_bfloat16), &ldb, &lda, &ldb, &alpha, &zerobeta, NULL, NULL);
629           }
630           lda = (libxsmm_blasint)handle->bk;
631           ldb = (libxsmm_blasint)handle->bn;
632           ldc = (libxsmm_blasint)handle->bk;
633           handle->gemm_upd.xgemm.bsmrs = libxsmm_bsmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &beta, NULL, NULL);
634           handle->gemm_upd2.xgemm.bmrs = libxsmm_bmmdispatch_reducebatch_strd(M, N, handle->bn, handle->bk*handle->bn*sizeof(libxsmm_bfloat16), handle->bc*handle->bn*sizeof(libxsmm_bfloat16), &lda, &ldb, &ldc, &alpha, &zerobeta, NULL, NULL);
635         } else {
636 
637         }
638       }
639     } else {
640       *status = LIBXSMM_DNN_ERR_CREATE_HANDLE;
641     }
642   } else {
643     *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
644   }
645 
646   return handle;
647 }
648 
649 
libxsmm_dnn_destroy_fullyconnected(const libxsmm_dnn_fullyconnected * handle)650 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_destroy_fullyconnected(const libxsmm_dnn_fullyconnected* handle) {
651   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
652 
653   if (0 != handle) {
654     /* Deallocate barrier */
655     if (handle->barrier != 0 ) { libxsmm_barrier_release((const libxsmm_barrier*)handle->barrier); }
656     /* deallocate handle structure */
657     free(/*remove constness*/(libxsmm_dnn_fullyconnected*)handle);
658   } else {
659     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
660   }
661 
662   return status;
663 }
664 
665 
libxsmm_dnn_fullyconnected_create_tensor_datalayout(const libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)666 LIBXSMM_API libxsmm_dnn_tensor_datalayout* libxsmm_dnn_fullyconnected_create_tensor_datalayout(const libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
667   libxsmm_dnn_tensor_datalayout* layout;
668 
669   *status = LIBXSMM_DNN_SUCCESS;
670   layout = 0;
671 
672   if (handle != 0) {
673     layout = (libxsmm_dnn_tensor_datalayout*) malloc(sizeof(libxsmm_dnn_tensor_datalayout));
674 
675     if (layout != 0) {
676       memset(layout, 0, sizeof(libxsmm_dnn_tensor_datalayout));
677 
678       if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)  || (type == LIBXSMM_DNN_INPUT)  ||
679            (type == LIBXSMM_DNN_REGULAR_OUTPUT)    || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT)    ) {
680         layout->format = handle->desc.buffer_format;
681         if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
682           if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
683             layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
684             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
685             layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
686 
687             if (0 != layout->dim_type && 0 != layout->dim_size) {
688               layout->num_dims = 5;
689               layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
690               layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
691               layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
692               layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
693               layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
694               if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)     || (type == LIBXSMM_DNN_INPUT)  ) {
695                 layout->dim_size[0] = handle->ifmblock;
696                 layout->dim_size[1] = 1;
697                 layout->dim_size[2] = 1;
698                 layout->dim_size[3] = handle->blocksifm;
699                 layout->dim_size[4] = handle->desc.N;
700               } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
701                 layout->dim_size[0] = handle->ofmblock;
702                 layout->dim_size[1] = 1;
703                 layout->dim_size[2] = 1;
704                 layout->dim_size[3] = handle->blocksofm;
705                 layout->dim_size[4] = handle->desc.N;
706               } else { /* coverity[dead_error_begin] */
707                 free(layout->dim_type);
708                 free(layout->dim_size);
709                 free(layout);
710                 layout = 0; /* make sure a NULL is returned */
711                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
712               }
713             } else {
714               free(layout);
715               layout = 0; /* make sure a NULL is returned */
716               *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
717             }
718           } else if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
719             if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) || (type == LIBXSMM_DNN_INPUT) ) {
720               layout->datatype = handle->desc.datatype_in;
721               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
722               layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
723               if (0 != layout->dim_type && 0 != layout->dim_size) {
724                 layout->num_dims = 5;
725                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
726                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
727                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
728                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
729                 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
730                 layout->dim_size[0] = handle->ifmblock;
731                 layout->dim_size[1] = 1;
732                 layout->dim_size[2] = 1;
733                 layout->dim_size[3] = handle->blocksifm;
734                 layout->dim_size[4] = handle->desc.N;
735               } else {
736                 free(layout->dim_type);
737                 free(layout->dim_size);
738                 free(layout);
739                 layout = 0; /* make sure a NULL is returned */
740                 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
741               }
742             } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) ) {
743               layout->datatype = handle->desc.datatype_out;
744               layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
745               layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
746               if (0 != layout->dim_type && 0 != layout->dim_size) {
747                 layout->num_dims = 5;
748                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
749                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
750                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
751                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
752                 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
753                 layout->dim_size[0] = handle->ofmblock;
754                 layout->dim_size[1] = 1;
755                 layout->dim_size[2] = 1;
756                 layout->dim_size[3] = handle->blocksofm;
757                 layout->dim_size[4] = handle->desc.N;
758               } else {
759                 free(layout->dim_type);
760                 free(layout->dim_size);
761                 free(layout);
762                 layout = 0; /* make sure a NULL is returned */
763                 *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
764               }
765             } else {
766               free(layout);
767               layout = 0; /* make sure a NULL is returned */
768               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
769             }
770           } else {
771             free(layout);
772             layout = 0; /* make sure a NULL is returned */
773             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
774           }
775         } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NHWC) > 0) {
776           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
777               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
778               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16))    ) {
779             layout->datatype = handle->desc.datatype_in;
780             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
781             layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
782             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
783               layout->num_dims = 4;
784               layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
785               layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_W;
786               layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_H;
787               layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
788               if ( (type == LIBXSMM_DNN_REGULAR_INPUT)     || (type == LIBXSMM_DNN_GRADIENT_INPUT)     || (type == LIBXSMM_DNN_INPUT)  )   {
789                 layout->dim_size[0] = handle->desc.C;
790                 layout->dim_size[1] = 1;
791                 layout->dim_size[2] = 1;
792                 layout->dim_size[3] = handle->desc.N;
793               } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) || (type == LIBXSMM_DNN_OUTPUT) )   {
794                 layout->dim_size[0] = handle->desc.K;
795                 layout->dim_size[1] = 1;
796                 layout->dim_size[2] = 1;
797                 layout->dim_size[3] = handle->desc.N;
798               } else {
799                 free(layout->dim_type);
800                 free(layout->dim_size);
801                 free(layout);
802                 layout = 0; /* make sure a NULL is returned */
803                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
804               }
805             }
806           } else {
807             free(layout);
808             layout = 0; /* make sure a NULL is returned */
809             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
810           }
811         } else if ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) {
812           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32)  && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32)) ||
813               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16))    ) {
814             layout->datatype = handle->desc.datatype_in;
815             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
816             layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
817 
818             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
819               layout->num_dims = 4;
820 
821               if ( (type == LIBXSMM_DNN_REGULAR_INPUT) || (type == LIBXSMM_DNN_GRADIENT_INPUT) ) {
822                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
823                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
824                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
825                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
826                 layout->dim_size[0] = (unsigned int)handle->bc;
827                 layout->dim_size[1] = (unsigned int)handle->bn;
828                 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
829                 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
830               } else if ( (type == LIBXSMM_DNN_REGULAR_OUTPUT) || (type == LIBXSMM_DNN_GRADIENT_OUTPUT) ) {
831                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
832                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
833                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
834                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_N;
835                 layout->dim_size[0] = (unsigned int)handle->bk;
836                 layout->dim_size[1] = (unsigned int)handle->bn;
837                 layout->dim_size[2] = (unsigned int)(handle->desc.K / handle->bk);
838                 layout->dim_size[3] = (unsigned int)(handle->desc.N / handle->bn);
839               } else {
840                 free(layout->dim_type);
841                 free(layout->dim_size);
842                 free(layout);
843                 layout = 0; /* make sure a NULL is returned */
844                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
845               }
846             } else {
847               free(layout);
848               layout = 0; /* make sure a NULL is returned */
849               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
850             }
851           } else {
852             free(layout);
853             layout = 0; /* make sure a NULL is returned */
854             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
855           }
856         } else {
857           free(layout);
858           layout = 0; /* make sure a NULL is returned */
859           *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
860         }
861       } else if ( (type == LIBXSMM_DNN_REGULAR_FILTER)  || (type == LIBXSMM_DNN_GRADIENT_FILTER)  || (type == LIBXSMM_DNN_FILTER)  ) {
862         layout->format = handle->desc.filter_format;
863         layout->tensor_type = LIBXSMM_DNN_FILTER;
864 
865         if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) > 0) {
866           if ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) {
867             layout->datatype = handle->desc.datatype_in;
868             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(6*sizeof(libxsmm_dnn_tensor_dimtype));
869             layout->dim_size = (unsigned int*) malloc(6*sizeof(unsigned int));
870             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
871               layout->num_dims = 6;
872               layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
873               layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
874               layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
875               layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
876               layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
877               layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
878               layout->dim_size[0] = handle->ofmblock;
879               layout->dim_size[1] = handle->ifmblock;
880               layout->dim_size[2] = 1;
881               layout->dim_size[3] = 1;
882               layout->dim_size[4] = handle->blocksifm;
883               layout->dim_size[5] = handle->blocksofm;
884             } else {
885               free(layout);
886               layout = 0; /* make sure a NULL is returned */
887               *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
888             }
889           } else if ( ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) ||
890               ( (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) )     ) {
891             layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
892             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(7*sizeof(libxsmm_dnn_tensor_dimtype));
893             layout->dim_size = (unsigned int*) malloc(7*sizeof(unsigned int));
894             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
895               layout->num_dims = 7;
896               layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
897               layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
898               layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
899               layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
900               layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
901               layout->dim_type[5] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
902               layout->dim_type[6] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
903               layout->dim_size[0] = handle->fm_lp_block;
904               layout->dim_size[1] = handle->ofmblock;
905               layout->dim_size[2] = handle->ifmblock/handle->fm_lp_block;
906               layout->dim_size[3] = 1;
907               layout->dim_size[4] = 1;
908               layout->dim_size[5] = handle->blocksifm;
909               layout->dim_size[6] = handle->blocksofm;
910             } else {
911               free(layout);
912               layout = 0; /* make sure a NULL is returned */
913               *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
914             }
915           } else {
916             free(layout);
917             layout = 0; /* make sure a NULL is returned */
918             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
919           }
920         } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_RSCK) > 0) {
921           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32))   ||
922               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32))  ||
923               ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16))    ) {
924             layout->datatype = handle->desc.datatype_in;
925             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
926             layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
927             if (0 != layout->dim_type && 0 != layout->dim_size) {
928               layout->num_dims = 4;
929               layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
930               layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
931               layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_S;
932               layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_R;
933               layout->dim_size[0] = handle->ofmblock * handle->blocksofm;
934               layout->dim_size[1] = handle->ifmblock * handle->blocksifm;
935               layout->dim_size[2] = 1;
936               layout->dim_size[3] = 1;
937             } else {
938               free(layout);
939               layout = 0; /* make sure a NULL is returned */
940               *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
941             }
942           } else {
943             free(layout);
944             layout = 0; /* make sure a NULL is returned */
945             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
946           }
947         } else if ((handle->desc.filter_format & LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) > 0) {
948           if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) ) ) {
949             layout->datatype = LIBXSMM_DNN_DATATYPE_F32;
950             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(4*sizeof(libxsmm_dnn_tensor_dimtype));
951             layout->dim_size = (unsigned int*) malloc(4*sizeof(unsigned int));
952 
953             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
954               layout->num_dims = 4;
955 
956               if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) {
957                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
958                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
959                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
960                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
961                 layout->dim_size[0] = (unsigned int)handle->bk;
962                 layout->dim_size[1] = (unsigned int)handle->bc;
963                 layout->dim_size[2] = (unsigned int)(handle->desc.C / handle->bc);
964                 layout->dim_size[3] = (unsigned int)(handle->desc.K / handle->bk);
965               } else {
966                 free(layout->dim_type);
967                 free(layout->dim_size);
968                 free(layout);
969                 layout = 0; /* make sure a NULL is returned */
970                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
971               }
972             } else {
973               free(layout);
974               layout = 0; /* make sure a NULL is returned */
975               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
976             }
977           } else if ( ((handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16) && (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) ) {
978             layout->datatype = LIBXSMM_DNN_DATATYPE_BF16;
979             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(5*sizeof(libxsmm_dnn_tensor_dimtype));
980             layout->dim_size = (unsigned int*) malloc(5*sizeof(unsigned int));
981 
982             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
983               layout->num_dims = 5;
984 
985               if ( (type == LIBXSMM_DNN_REGULAR_FILTER) || (type == LIBXSMM_DNN_GRADIENT_FILTER) ) {
986                 layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
987                 layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
988                 layout->dim_type[2] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
989                 layout->dim_type[3] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
990                 layout->dim_type[4] = LIBXSMM_DNN_TENSOR_DIMTYPE_K;
991                 layout->dim_size[0] = (unsigned int)2;
992                 layout->dim_size[1] = (unsigned int)handle->bk;
993                 layout->dim_size[2] = (unsigned int)handle->bc/2;
994                 layout->dim_size[3] = (unsigned int)(handle->desc.C / handle->bc);
995                 layout->dim_size[4] = (unsigned int)(handle->desc.K / handle->bk);
996               } else {
997                 free(layout->dim_type);
998                 free(layout->dim_size);
999                 free(layout);
1000                 layout = 0; /* make sure a NULL is returned */
1001                 *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1002               }
1003             } else {
1004               free(layout);
1005               layout = 0; /* make sure a NULL is returned */
1006               *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1007             }
1008           } else {
1009             free(layout);
1010             layout = 0; /* make sure a NULL is returned */
1011             *status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
1012           }
1013         } else {
1014           free(layout);
1015           layout = 0; /* make sure a NULL is returned */
1016           *status = LIBXSMM_DNN_ERR_INVALID_FORMAT_GENERAL;
1017         }
1018       } else if ( (type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) || (type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) || (type == LIBXSMM_DNN_CHANNEL_BIAS) ) {
1019         layout->format = handle->desc.buffer_format;
1020         layout->tensor_type = LIBXSMM_DNN_CHANNEL_SCALAR;
1021 
1022         if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
1023           if ( (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32) || (handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16) ) {
1024             layout->datatype = handle->desc.datatype_out;
1025             layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(2*sizeof(libxsmm_dnn_tensor_dimtype));
1026             layout->dim_size = (unsigned int*) malloc(2*sizeof(unsigned int));
1027 
1028             if (0 != layout->dim_type && 0 != layout->dim_size) { /* TODO: handle the error */
1029               layout->num_dims = 2;
1030               layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
1031               layout->dim_type[1] = LIBXSMM_DNN_TENSOR_DIMTYPE_C;
1032               layout->dim_size[0] = (unsigned int)handle->bk;
1033               layout->dim_size[1] = (unsigned int)(handle->desc.K / handle->bk);
1034             } else {
1035               free(layout->dim_type);
1036               free(layout->dim_size);
1037               free(layout);
1038               layout = 0; /* make sure a NULL is returned */
1039               *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
1040             }
1041           }
1042         } else {
1043           free(layout);
1044           layout = 0; /* make sure a NULL is returned */
1045           *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1046         }
1047       } else if ( (type == LIBXSMM_DNN_RELU_MASK) ) {
1048         layout->format = handle->desc.buffer_format;
1049         layout->tensor_type = LIBXSMM_DNN_RELU_MASK;
1050 
1051         if ( ((handle->desc.buffer_format & LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) > 0) ) {
1052           layout->datatype = LIBXSMM_DNN_DATATYPE_I8;
1053           layout->dim_type = (libxsmm_dnn_tensor_dimtype*) malloc(1*sizeof(libxsmm_dnn_tensor_dimtype));
1054           layout->dim_size = (unsigned int*) malloc(1*sizeof(unsigned int));
1055 
1056           if (0 != layout->dim_type && 0 != layout->dim_size) {
1057             layout->num_dims = 1;
1058             layout->dim_type[0] = LIBXSMM_DNN_TENSOR_DIMTYPE_X;
1059             layout->dim_size[0] = handle->desc.N * handle->desc.K;
1060           } else {
1061             free(layout->dim_type);
1062             free(layout->dim_size);
1063             free(layout);
1064             layout = 0; /* make sure a NULL is returned */
1065             *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT_ARRAYS;
1066           }
1067         } else {
1068           free(layout);
1069           layout = 0; /* make sure a NULL is returned */
1070           *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1071         }
1072       } else {
1073         free(layout);
1074         layout = 0; /* make sure a NULL is returned */
1075         *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1076       }
1077     } else {
1078       *status = LIBXSMM_DNN_ERR_CREATE_LAYOUT;
1079     }
1080   }
1081   else {
1082     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1083   }
1084 
1085   return layout;
1086 }
1087 
libxsmm_dnn_fullyconnected_get_scratch_size(const libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_err_t * status)1088 LIBXSMM_API size_t libxsmm_dnn_fullyconnected_get_scratch_size(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status) {
1089   size_t l_scratch_size = 0;
1090   *status = LIBXSMM_DNN_SUCCESS;
1091 
1092   if (0 != handle) {
1093     l_scratch_size = handle->scratch_size + 64; /* 64 byte extra in case the user code does not care about alignment */
1094   } else {
1095     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1096   }
1097 
1098   return l_scratch_size;
1099 }
1100 
1101 
libxsmm_dnn_fullyconnected_get_scratch_ptr(const libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_err_t * status)1102 LIBXSMM_API void* libxsmm_dnn_fullyconnected_get_scratch_ptr(const libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_err_t* status)
1103 {
1104   *status = LIBXSMM_DNN_SUCCESS;
1105 
1106   if (0 != handle) {
1107     return handle->scratch;
1108   } else {
1109     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1110   }
1111 
1112   return 0;
1113 }
1114 
1115 
libxsmm_dnn_fullyconnected_bind_scratch(libxsmm_dnn_fullyconnected * handle,const void * scratch)1116 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_scratch(libxsmm_dnn_fullyconnected* handle, const void* scratch) {
1117   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1118   uintptr_t address = (uintptr_t)scratch;
1119   size_t offset = 0;
1120 
1121   if (scratch == 0) {
1122     status = LIBXSMM_DNN_ERR_SCRATCH_NOT_ALLOCED;
1123     return status;
1124   }
1125 
1126   if (0 != handle) {
1127     /* align the internal scratch buffer if needed */
1128     if (address % 64 == 0) {
1129       handle->scratch = (void*)address;
1130     } else {
1131       offset = (64 - address % 64);
1132       handle->scratch = (void*)(address+offset);
1133     }
1134   } else {
1135     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1136   }
1137 
1138   return status;
1139 }
1140 
1141 
libxsmm_dnn_fullyconnected_release_scratch(libxsmm_dnn_fullyconnected * handle)1142 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_scratch(libxsmm_dnn_fullyconnected* handle) {
1143   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1144 
1145   if (0 != handle) {
1146     handle->scratch = 0;
1147   } else {
1148     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1149   }
1150 
1151   return status;
1152 }
1153 
1154 
libxsmm_dnn_fullyconnected_bind_tensor(libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor * tensor,const libxsmm_dnn_tensor_type type)1155 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_bind_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor* tensor, const libxsmm_dnn_tensor_type type) {
1156   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1157 
1158   /* check for tensor type */
1159   if ( (type != LIBXSMM_DNN_REGULAR_INPUT)        && (type != LIBXSMM_DNN_GRADIENT_INPUT)        &&
1160        (type != LIBXSMM_DNN_REGULAR_OUTPUT)       && (type != LIBXSMM_DNN_GRADIENT_OUTPUT)       &&
1161        (type != LIBXSMM_DNN_REGULAR_FILTER)       && (type != LIBXSMM_DNN_GRADIENT_FILTER)       &&
1162        (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
1163        (type != LIBXSMM_DNN_RELU_MASK)  ) {
1164     status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1165     return status;
1166   }
1167 
1168   if (handle != 0 && tensor != 0) {
1169     libxsmm_dnn_tensor_datalayout* handle_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout(handle, type, &status);
1170 
1171     if ( libxsmm_dnn_compare_tensor_datalayout(handle_layout, tensor->layout, &status) == 0 ) {
1172       if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
1173         handle->reg_input = (libxsmm_dnn_tensor*)tensor;
1174       } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
1175         handle->grad_input = (libxsmm_dnn_tensor*)tensor;
1176       } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
1177         handle->reg_output = (libxsmm_dnn_tensor*)tensor;
1178       } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
1179         handle->grad_output = (libxsmm_dnn_tensor*)tensor;
1180       } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
1181         handle->reg_filter = (libxsmm_dnn_tensor*)tensor;
1182       } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
1183         handle->grad_filter = (libxsmm_dnn_tensor*)tensor;
1184       } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
1185         handle->reg_bias = (libxsmm_dnn_tensor*)tensor;
1186       } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
1187         handle->grad_bias = (libxsmm_dnn_tensor*)tensor;
1188       } else if ( type == LIBXSMM_DNN_RELU_MASK ) {
1189         handle->relumask = (libxsmm_dnn_tensor*)tensor;
1190       } else {
1191         /* cannot happen */
1192       }
1193     } else {
1194       status = LIBXSMM_DNN_ERR_MISMATCH_TENSOR;
1195     }
1196 
1197     libxsmm_dnn_destroy_tensor_datalayout( handle_layout );
1198   }
1199   else {
1200     status = LIBXSMM_DNN_ERR_INVALID_HANDLE_TENSOR;
1201   }
1202 
1203   return status;
1204 }
1205 
1206 
libxsmm_dnn_fullyconnected_get_tensor(libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor_type type,libxsmm_dnn_err_t * status)1207 LIBXSMM_API libxsmm_dnn_tensor* libxsmm_dnn_fullyconnected_get_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type, libxsmm_dnn_err_t* status) {
1208   libxsmm_dnn_tensor* return_tensor = 0;
1209 
1210   *status = LIBXSMM_DNN_SUCCESS;
1211 
1212   /* check for tensor type */
1213   if ( (type != LIBXSMM_DNN_REGULAR_INPUT)        && (type != LIBXSMM_DNN_GRADIENT_INPUT)        &&
1214        (type != LIBXSMM_DNN_REGULAR_OUTPUT)       && (type != LIBXSMM_DNN_GRADIENT_OUTPUT)       &&
1215        (type != LIBXSMM_DNN_REGULAR_FILTER)       && (type != LIBXSMM_DNN_GRADIENT_FILTER)       &&
1216        (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
1217        (type != LIBXSMM_DNN_RELU_MASK)  ) {
1218     *status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1219     return return_tensor;
1220   }
1221 
1222   if (handle != 0) {
1223     if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
1224       return_tensor = handle->reg_input;
1225     } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
1226       return_tensor = handle->grad_input;
1227     } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
1228       return_tensor = handle->reg_output;
1229     } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
1230       return_tensor = handle->grad_output;
1231     } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
1232       return_tensor = handle->reg_filter;
1233     } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
1234       return_tensor = handle->grad_filter;
1235     } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
1236       return_tensor = handle->reg_bias;
1237     } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
1238       return_tensor = handle->grad_bias;
1239     } else if ( type == LIBXSMM_DNN_RELU_MASK ) {
1240       return_tensor = handle->relumask;
1241     } else {
1242       /* cannot happen */
1243     }
1244   } else {
1245     *status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1246   }
1247 
1248   return return_tensor;
1249 }
1250 
1251 
libxsmm_dnn_fullyconnected_release_tensor(libxsmm_dnn_fullyconnected * handle,const libxsmm_dnn_tensor_type type)1252 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_release_tensor(libxsmm_dnn_fullyconnected* handle, const libxsmm_dnn_tensor_type type) {
1253   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1254 
1255   /* check for tensor type */
1256   if ( (type != LIBXSMM_DNN_REGULAR_INPUT)        && (type != LIBXSMM_DNN_GRADIENT_INPUT)        &&
1257        (type != LIBXSMM_DNN_REGULAR_OUTPUT)       && (type != LIBXSMM_DNN_GRADIENT_OUTPUT)       &&
1258        (type != LIBXSMM_DNN_REGULAR_FILTER)       && (type != LIBXSMM_DNN_GRADIENT_FILTER)       &&
1259        (type != LIBXSMM_DNN_REGULAR_CHANNEL_BIAS) && (type != LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS) &&
1260        (type != LIBXSMM_DNN_RELU_MASK)  ) {
1261     status = LIBXSMM_DNN_ERR_UNKNOWN_TENSOR_TYPE;
1262     return status;
1263   }
1264 
1265   if (handle != 0) {
1266     if ( type == LIBXSMM_DNN_REGULAR_INPUT ) {
1267       handle->reg_input = 0;
1268     } else if ( type == LIBXSMM_DNN_GRADIENT_INPUT ) {
1269       handle->grad_input = 0;
1270     } else if ( type == LIBXSMM_DNN_REGULAR_OUTPUT ) {
1271       handle->reg_output = 0;
1272     } else if ( type == LIBXSMM_DNN_GRADIENT_OUTPUT ) {
1273       handle->grad_output = 0;
1274     } else if ( type == LIBXSMM_DNN_REGULAR_FILTER ) {
1275       handle->reg_filter = 0;
1276     } else if ( type == LIBXSMM_DNN_GRADIENT_FILTER ) {
1277       handle->grad_filter = 0;
1278     } else if ( type == LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) {
1279       handle->reg_bias = 0;
1280     } else if ( type == LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) {
1281       handle->grad_bias = 0;
1282     } else if ( type == LIBXSMM_DNN_RELU_MASK ) {
1283       handle->relumask = 0;
1284     } else {
1285       /* cannot happen */
1286     }
1287   } else {
1288     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1289   }
1290 
1291   return status;
1292 }
1293 
1294 
libxsmm_dnn_fullyconnected_execute_st(libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)1295 LIBXSMM_API libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_execute_st(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind,
1296     /*unsigned*/int start_thread, /*unsigned*/int tid) {
1297   libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
1298   LIBXSMM_UNUSED( start_thread );
1299   LIBXSMM_UNUSED( tid );
1300 
1301   if (0 != handle) {
1302     switch (kind) {
1303       case LIBXSMM_DNN_COMPUTE_KIND_FWD: {
1304         if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) ) {
1305           status = libxsmm_dnn_fullyconnected_st_fwd_custom( handle, start_thread, tid );
1306         } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
1307           status = libxsmm_dnn_fullyconnected_st_fwd_ncnc_kcck( handle, start_thread, tid );
1308         } else {
1309           status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC;
1310         }
1311       } break;
1312       case LIBXSMM_DNN_COMPUTE_KIND_BWD:
1313       case LIBXSMM_DNN_COMPUTE_KIND_UPD:
1314       case LIBXSMM_DNN_COMPUTE_KIND_BWDUPD: {
1315         if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM) ) {
1316           status = libxsmm_dnn_fullyconnected_st_bwdupd_custom( handle, kind, start_thread, tid );
1317         } else if ( (handle->desc.buffer_format == LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED) && (handle->desc.filter_format == LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED) ) {
1318           status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck( handle, kind, start_thread, tid );
1319         } else {
1320           status = LIBXSMM_DNN_ERR_INVALID_FORMAT_FC;
1321         }
1322       } break;
1323       default: {
1324         status = LIBXSMM_DNN_ERR_INVALID_KIND;
1325       }
1326     }
1327   }
1328   else {
1329     status = LIBXSMM_DNN_ERR_INVALID_HANDLE;
1330   }
1331 
1332   return status;
1333 }
1334 
1335