1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 
13 #include <stdlib.h>
14 #include <string.h>
15 #include <stdio.h>
16 #include <math.h>
17 #if defined(_OPENMP)
18 # include <omp.h>
19 #endif
20 
21 /* include c-based dnn library */
22 #include "../common/dnn_common.h"
23 
/* Evaluate the LIBXSMM-DNN status expression A exactly once. If the result is
 * not LIBXSMM_DNN_SUCCESS, print the library's error string to stderr and
 * store the code in `global_status`, which must be a variable in scope at the
 * expansion site. Execution continues in either case (no abort/return).
 * The do { } while (0) wrapper makes each expansion a single statement, so
 * `if (c) CHKERR_LIBXSMM_DNN(x); else ...` parses correctly.
 * NOTE(review): when expanded inside an `omp parallel` region, several threads
 * may write global_status without synchronization. */
#define CHKERR_LIBXSMM_DNN(A) do { \
  const int chkerr_libxsmm_dnn_ = (A); \
  if (LIBXSMM_DNN_SUCCESS != chkerr_libxsmm_dnn_) { \
    fprintf(stderr, "%s\n", libxsmm_dnn_get_error(chkerr_libxsmm_dnn_)); \
    global_status = chkerr_libxsmm_dnn_; \
  } \
} while (0)
27 
main(int argc,char * argv[])28 int main(int argc, char* argv[])
29 {
30   float *naive_input, *naive_output, *naive_filter, *naive_delinput, *naive_deloutput, *naive_delfilter, *naive_bias, *naive_delbias;
31   libxsmm_bfloat16 *naive_input_bf16, *naive_filter_bf16, *naive_output_bf16, *naive_delinput_bf16, *naive_delfilter_bf16, *naive_deloutput_bf16, *naive_bias_bf16, *naive_delbias_bf16;
32   float *naive_libxsmm_output_f32, *naive_libxsmm_delinput_f32, *naive_libxsmm_delfilter_f32, *naive_libxsmm_delbias_f32;
33   libxsmm_bfloat16 *naive_libxsmm_output_bf16, *naive_libxsmm_delinput_bf16, *naive_libxsmm_delfilter_bf16, *naive_libxsmm_delbias_bf16;
34   libxsmm_bfloat16 *input_libxsmm, *filter_libxsmm, *delinput_libxsmm, *delfilter_libxsmm, *output_libxsmm, *deloutput_libxsmm, *bias_libxsmm, *delbias_libxsmm;
35   unsigned char *relumask_libxsmm;
36 
37   naive_fullyconnected_t naive_param;
38   void* scratch;
39   size_t scratch_size = 0;
40 
41   /* some parameters we can overwrite via cli,
42      default is some inner layer of overfeat */
43   int iters = 100;         /* repetitions of benchmark */
44   int nImg = 32;          /* mini-batch size, "N" */
45   int nIFm = 256;          /* number of input feature maps, "C" */
46   int nOFm = 256;          /* number of input feature maps, "C" */
47   int fuse_type = 0;      /* 0: nothing fused, 1: relu fused, 2: elementwise fused, 3: relu and elementwise fused */
48   char type = 'A';        /* 'A': ALL, 'F': FP, 'B': BP, 'U', WU */
49   char format = 'B';
50   int bn = 32;
51   int bk = 32;
52   int bc = 32;
53 
54   const char *const env_check = getenv("CHECK");
55   const double check = LIBXSMM_ABS(0 == env_check ? 1 : atof(env_check));
56 
57 #if defined(_OPENMP)
58   int nThreads = omp_get_max_threads(); /* number of threads */
59 #else
60   int nThreads = 1; /* number of threads */
61 #endif
62 
63   unsigned long long l_start, l_end;
64   double l_total = 0.0;
65   double gflop = 0.0;
66   int i;
67 
68   libxsmm_dnn_fullyconnected_desc fullyconnected_desc;
69   libxsmm_dnn_fullyconnected* libxsmm_handle;
70   libxsmm_dnn_tensor*  libxsmm_input;
71   libxsmm_dnn_tensor*  libxsmm_delinput;
72   libxsmm_dnn_tensor*  libxsmm_output;
73   libxsmm_dnn_tensor*  libxsmm_deloutput;
74   libxsmm_dnn_tensor*  libxsmm_filter;
75   libxsmm_dnn_tensor*  libxsmm_delfilter;
76   libxsmm_dnn_tensor*  libxsmm_bias;
77   libxsmm_dnn_tensor*  libxsmm_delbias;
78   libxsmm_dnn_tensor*  libxsmm_relumask;
79   libxsmm_dnn_tensor_datalayout* libxsmm_layout;
80   libxsmm_dnn_err_t status;
81   libxsmm_dnn_err_t global_status = LIBXSMM_DNN_SUCCESS;
82 
83   libxsmm_matdiff_info norms_fwd, norms_bwd, norms_upd, diff;
84   libxsmm_matdiff_clear(&norms_fwd);
85   libxsmm_matdiff_clear(&norms_bwd);
86   libxsmm_matdiff_clear(&norms_upd);
87   libxsmm_matdiff_clear(&diff);
88 
89   if (argc > 1 && !strncmp(argv[1], "-h", 3)) {
90     printf("Usage: %s iters nImg nIFm nOFm fuse_type type format\n", argv[0]);
91     return 0;
92   }
93   libxsmm_rng_set_seed(1);
94 
95   /* reading new values from cli */
96   i = 1;
97   if (argc > i) iters      = atoi(argv[i++]);
98   if (argc > i) nImg       = atoi(argv[i++]);
99   if (argc > i) nIFm       = atoi(argv[i++]);
100   if (argc > i) nOFm       = atoi(argv[i++]);
101   if (argc > i) fuse_type  = atoi(argv[i++]);
102   if (argc > i) type       = *(argv[i++]);
103   if (argc > i) format     = *(argv[i++]);
104   if (argc > i) bn         = atoi(argv[i++]);
105   if (argc > i) bk         = atoi(argv[i++]);
106   if (argc > i) bc         = atoi(argv[i++]);
107 
108   /* These are tuning parameters to be attached to the perfdump string  */
109 #if 0
110   int fwd_bf = atoi(getenv("FWD_BF"));
111   int bwd_bf = atoi(getenv("BWD_BF"));
112   int upd_bf = atoi(getenv("UPD_BF"));
113   int fwd_2d_blocking = atoi(getenv("FWD_2D_BLOCKING"));
114   int bwd_2d_blocking = atoi(getenv("BWD_2D_BLOCKING"));
115   int upd_2d_blocking = atoi(getenv("UPD_2D_BLOCKING"));
116   int fwd_row_teams = atoi(getenv("FWD_ROW_TEAMS"));
117   int fwd_column_teams = atoi(getenv("FWD_COLUMN_TEAMS"));
118   int bwd_row_teams = atoi(getenv("BWD_ROW_TEAMS"));
119   int bwd_column_teams = atoi(getenv("BWD_COLUMN_TEAMS"));
120   int upd_row_teams = atoi(getenv("UPD_ROW_TEAMS"));
121   int upd_column_teams = atoi(getenv("UPD_COLUMN_TEAMS"));
122   int ifm_subtasks = atoi(getenv("IFM_SUBTASKS"));
123   int ofm_subtasks = atoi(getenv("OFM_SUBTASKS"));
124 #endif
125   int fwd_bf = 1;
126   int bwd_bf = 1;
127   int upd_bf = 1;
128   int fwd_2d_blocking = 1;
129   int bwd_2d_blocking = 1;
130   int upd_2d_blocking = 1;
131   int fwd_row_teams = 1;
132   int fwd_column_teams = 1;
133   int bwd_row_teams = 1;
134   int bwd_column_teams = 1;
135   int upd_row_teams = 1;
136   int upd_column_teams = 1;
137   int ifm_subtasks = 1;
138   int ofm_subtasks = 1;
139 
140   if ( nImg % bn != 0 ) {
141     bn = nImg;
142   }
143   if ( nIFm % bc != 0 ) {
144     bc = nIFm;
145   }
146   if ( nOFm % bk != 0 ) {
147     bk = nOFm;
148   }
149 
150   if (type != 'A' && type != 'F' && type != 'B' && type != 'U' && type != 'M') {
151     printf("type needs to be 'A' (All), 'F' (FP only), 'B' (BP only), 'U' (UP only). 'M' (BPUP-fused only)\n");
152     return -1;
153   }
154   if ( (fuse_type < 0) || (fuse_type > 5) ) {
155     printf("fuse type needs to be 0 (None), 1 (Bias), 2 (ReLU), 3 (Sigmoid), 4 (Bias+ReLU), 5 (Bias+Sigmoid)\n");
156     return -1;
157   }
158   if (format != 'L' && format != 'B') {
159     printf("format needs to be 'L' (libxsmm) or 'B' (for locked NCNC KCCK)\n");
160     return -1;
161   }
162 
163   /* set struct for naive convolution */
164   naive_param.N = nImg;
165   naive_param.C = nIFm;
166   naive_param.K = nOFm;
167   naive_param.fuse_type = fuse_type;
168 
169 #if defined(__SSE3__)
170   _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
171   _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
172   _MM_SET_ROUNDING_MODE(_MM_ROUND_NEAREST);
173 #endif
174 
175   /* print some summary */
176   printf("##########################################\n");
177   printf("#          Setting Up (Common)           #\n");
178   printf("##########################################\n");
179   printf("PARAMS: N:%d  C:%d  K:%d\n", nImg, nIFm, nOFm);
180   printf("PARAMS: ITERS:%d", iters); if (LIBXSMM_FEQ(0, check)) printf("  Threads:%d\n", nThreads); else printf("\n");
181   printf("SIZE Input  (MB): %10.2f MiB\n", (double)(nImg*nIFm*sizeof(libxsmm_bfloat16))/(1024.0*1024.0) );
182   printf("SIZE Output (MB): %10.2f MiB\n", (double)(nImg*nOFm*sizeof(float))/(1024.0*1024.0) );
183   printf("SIZE Input   (1): %10.2f MiB\n", (double)(1*nIFm*   sizeof(libxsmm_bfloat16))/(1024.0*1024.0) );
184   printf("SIZE Output  (1): %10.2f MiB\n", (double)(1*nOFm*   sizeof(float))/(1024.0*1024.0) );
185   printf("SIZE Filter     : %10.2f MiB\n", (double)(nIFm*nOFm*sizeof(libxsmm_bfloat16))/(1024.0*1024.0) );
186 
187   /* allocate data */
188   naive_input                 = (float*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(float), 2097152);
189   naive_delinput              = (float*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(float), 2097152);
190   naive_output                = (float*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(float), 2097152);
191   naive_deloutput             = (float*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(float), 2097152);
192   naive_filter                = (float*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(float), 2097152);
193   naive_delfilter             = (float*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(float), 2097152);
194   naive_bias                  = (float*)libxsmm_aligned_malloc( nOFm     *sizeof(float), 2097152);
195   naive_delbias               = (float*)libxsmm_aligned_malloc( nOFm     *sizeof(float), 2097152);
196 
197   naive_input_bf16            = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(libxsmm_bfloat16), 2097152);
198   naive_delinput_bf16         = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(libxsmm_bfloat16), 2097152);
199   naive_output_bf16           = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(libxsmm_bfloat16), 2097152);
200   naive_deloutput_bf16        = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(libxsmm_bfloat16), 2097152);
201   naive_filter_bf16           = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(libxsmm_bfloat16), 2097152);
202   naive_delfilter_bf16        = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(libxsmm_bfloat16), 2097152);
203   naive_bias_bf16             = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nOFm     *sizeof(libxsmm_bfloat16), 2097152);
204   naive_delbias_bf16          = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nOFm     *sizeof(libxsmm_bfloat16), 2097152);
205 
206   naive_libxsmm_output_bf16   = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(libxsmm_bfloat16), 2097152);
207   naive_libxsmm_delinput_bf16 = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(libxsmm_bfloat16), 2097152);
208   naive_libxsmm_delfilter_bf16= (libxsmm_bfloat16*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(libxsmm_bfloat16), 2097152);
209   naive_libxsmm_delbias_bf16  = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nOFm     *sizeof(libxsmm_bfloat16), 2097152);
210   naive_libxsmm_output_f32    = (float*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(float), 2097152);
211   naive_libxsmm_delinput_f32  = (float*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(float), 2097152);
212   naive_libxsmm_delfilter_f32 = (float*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(float), 2097152);
213   naive_libxsmm_delbias_f32   = (float*)libxsmm_aligned_malloc( nOFm     *sizeof(float), 2097152);
214 
215   input_libxsmm               = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(libxsmm_bfloat16), 2097152);
216   delinput_libxsmm            = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nIFm*sizeof(libxsmm_bfloat16), 2097152);
217   output_libxsmm              = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(libxsmm_bfloat16), 2097152);
218   deloutput_libxsmm           = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(libxsmm_bfloat16), 2097152);
219   filter_libxsmm              = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(libxsmm_bfloat16), 2097152);
220   delfilter_libxsmm           = (libxsmm_bfloat16*)libxsmm_aligned_malloc( nIFm*nOFm*sizeof(libxsmm_bfloat16), 2097152);
221   bias_libxsmm                =  (libxsmm_bfloat16*)libxsmm_aligned_malloc( nOFm     *sizeof(libxsmm_bfloat16), 2097152);
222   delbias_libxsmm             =  (libxsmm_bfloat16*)libxsmm_aligned_malloc( nOFm     *sizeof(libxsmm_bfloat16), 2097152);
223   relumask_libxsmm            =  (unsigned char*)libxsmm_aligned_malloc( nImg*nOFm*sizeof(unsigned char), 2097152);
224 
225   /* initialize data */
226   init_buf( naive_input,     nImg*nIFm, 0, 0 );
227   init_buf( naive_delinput,  nImg*nIFm, 0, 0 );
228   init_buf( naive_output,    nImg*nOFm, 0, 0 );
229   init_buf( naive_deloutput, nImg*nOFm, 0, 0 );
230   init_buf( naive_filter,    nIFm*nOFm, 0, 0 );
231   init_buf( naive_delfilter, nIFm*nOFm, 0, 0 );
232   init_buf( naive_bias,      nOFm,      0, 0 );
233   init_buf( naive_delbias,   nOFm,      0, 0 );
234 
235   libxsmm_rne_convert_fp32_bf16( naive_input,     naive_input_bf16,     nImg*nIFm );
236   libxsmm_rne_convert_fp32_bf16( naive_delinput,  naive_delinput_bf16,  nImg*nIFm );
237   libxsmm_rne_convert_fp32_bf16( naive_output,    naive_output_bf16,    nImg*nOFm );
238   libxsmm_rne_convert_fp32_bf16( naive_deloutput, naive_deloutput_bf16, nImg*nOFm );
239   libxsmm_rne_convert_fp32_bf16( naive_filter,    naive_filter_bf16,    nIFm*nOFm );
240   libxsmm_rne_convert_fp32_bf16( naive_delfilter, naive_delfilter_bf16, nIFm*nOFm );
241   libxsmm_rne_convert_fp32_bf16( naive_bias,      naive_bias_bf16,      nOFm );
242   libxsmm_rne_convert_fp32_bf16( naive_delbias,   naive_delbias_bf16,   nOFm );
243 
244   if (LIBXSMM_NEQ(0, check)) {
245     printf("##########################################\n");
246     printf("#         Computing Reference ...        #\n");
247     printf("##########################################\n");
248     if (type == 'A' || type == 'F') {
249       naive_fullyconnected_fused_fp(&naive_param, naive_input, naive_output, naive_filter, naive_bias);
250     }
251     if (type == 'A' || type == 'B' || type == 'M') {
252       naive_fullyconnected_fused_bp(&naive_param, naive_delinput, naive_deloutput, naive_filter, naive_delbias, naive_output);
253     }
254     if (type == 'A' || type == 'U' || type == 'M') {
255       naive_fullyconnected_wu(&naive_param, naive_input, naive_deloutput, naive_delfilter);
256     }
257     printf("##########################################\n");
258     printf("#      Computing Reference ... done      #\n");
259     printf("##########################################\n");
260   }
261 
262   if (format == 'A' || format == 'B') {
263     printf("\n");
264     printf("##########################################\n");
265     printf("#      Setting Up  (custom-Storage)      #\n");
266     printf("##########################################\n");
267 
268     /* setup LIBXSMM handle */
269     fullyconnected_desc.N = nImg;
270     fullyconnected_desc.C = nIFm;
271     fullyconnected_desc.K = nOFm;
272     fullyconnected_desc.bn = bn;
273     fullyconnected_desc.bk = bk;
274     fullyconnected_desc.bc = bc;
275     fullyconnected_desc.threads = nThreads;
276     fullyconnected_desc.datatype_in = LIBXSMM_DNN_DATATYPE_BF16;
277     fullyconnected_desc.datatype_out = LIBXSMM_DNN_DATATYPE_BF16;
278     fullyconnected_desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NCPACKED;
279     fullyconnected_desc.filter_format = LIBXSMM_DNN_TENSOR_FORMAT_CKPACKED;
280     if ( fuse_type == 0 ) {
281       fullyconnected_desc.fuse_ops = LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE;
282     } else if ( fuse_type == 1 ) {
283       fullyconnected_desc.fuse_ops = LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS;
284     } else if ( fuse_type == 2 ) {
285       fullyconnected_desc.fuse_ops = LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU;
286     } else if ( fuse_type == 3 ) {
287       fullyconnected_desc.fuse_ops = LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID;
288     } else if ( fuse_type == 4 ) {
289       fullyconnected_desc.fuse_ops = LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU;
290     } else if ( fuse_type == 5 ) {
291       fullyconnected_desc.fuse_ops = LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID;
292     } else {
293       /* cannot happen */
294     }
295 
296     libxsmm_handle = libxsmm_dnn_create_fullyconnected( fullyconnected_desc, &status );
297     CHKERR_LIBXSMM_DNN( status );
298 
299     /* setup LIBXSMM buffers */
300     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
301     libxsmm_input  = libxsmm_dnn_link_tensor( libxsmm_layout, input_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
302     printf("inner activation blocking: %i\n", libxsmm_layout->dim_size[0] );
303     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
304 
305     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT, &status ); CHKERR_LIBXSMM_DNN( status );
306     libxsmm_delinput  = libxsmm_dnn_link_tensor( libxsmm_layout, delinput_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
307     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
308 
309     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT, &status ); CHKERR_LIBXSMM_DNN( status );
310     libxsmm_output  = libxsmm_dnn_link_tensor( libxsmm_layout, output_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
311     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
312 
313     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT, &status ); CHKERR_LIBXSMM_DNN( status );
314     libxsmm_deloutput  = libxsmm_dnn_link_tensor( libxsmm_layout, deloutput_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
315     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
316 
317     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER, &status ); CHKERR_LIBXSMM_DNN( status );
318     libxsmm_filter  = libxsmm_dnn_link_tensor( libxsmm_layout, filter_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
319     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
320 
321     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER, &status ); CHKERR_LIBXSMM_DNN( status );
322     libxsmm_delfilter  = libxsmm_dnn_link_tensor( libxsmm_layout, delfilter_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
323     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
324 
325     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_REGULAR_CHANNEL_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
326     libxsmm_bias = libxsmm_dnn_link_tensor( libxsmm_layout, bias_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
327     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
328 
329     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS, &status ); CHKERR_LIBXSMM_DNN( status );
330     libxsmm_delbias  = libxsmm_dnn_link_tensor( libxsmm_layout, delbias_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
331     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
332 
333     libxsmm_layout = libxsmm_dnn_fullyconnected_create_tensor_datalayout( libxsmm_handle, LIBXSMM_DNN_RELU_MASK, &status ); CHKERR_LIBXSMM_DNN( status );
334     libxsmm_relumask  = libxsmm_dnn_link_tensor( libxsmm_layout, relumask_libxsmm, &status ); CHKERR_LIBXSMM_DNN( status );
335     libxsmm_dnn_destroy_tensor_datalayout( libxsmm_layout );
336 
337     /* copy in data to LIBXSMM format */
338     /* we can also use the layout functions and set the data on our
339        own external to the library */
340     matrix_copy_NC_to_NCNC_bf16( naive_input_bf16,     input_libxsmm,     1, nImg, nIFm, bn, bc );
341     matrix_copy_NC_to_NCNC_bf16( naive_delinput_bf16,  delinput_libxsmm,  1, nImg, nIFm, bn, bc );
342     matrix_copy_NC_to_NCNC_bf16( naive_output_bf16,    output_libxsmm,    1, nImg, nOFm, bn, bk );
343     matrix_copy_NC_to_NCNC_bf16( naive_deloutput_bf16, deloutput_libxsmm, 1, nImg, nOFm, bn, bk );
344     matrix_copy_KC_to_KCCK_bf16( naive_filter_bf16,    filter_libxsmm      , nIFm, nOFm, bc, bk );
345     matrix_copy_KC_to_KCCK_bf16( naive_delfilter_bf16, delfilter_libxsmm   , nIFm, nOFm, bc, bk );
346     matrix_copy_NC_to_NCNC_bf16( naive_bias_bf16,    bias_libxsmm,    1, 1, nOFm, 1, nOFm );
347     matrix_copy_NC_to_NCNC_bf16( naive_delbias_bf16, delbias_libxsmm, 1, 1, nOFm, 1, nOFm );
348 
349     /* bind buffers and filter to handle */
350     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_input,        LIBXSMM_DNN_REGULAR_INPUT ) );
351     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_delinput,     LIBXSMM_DNN_GRADIENT_INPUT ) );
352     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_output,       LIBXSMM_DNN_REGULAR_OUTPUT ) );
353     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_deloutput,    LIBXSMM_DNN_GRADIENT_OUTPUT ) );
354     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_filter,       LIBXSMM_DNN_REGULAR_FILTER ) );
355     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_delfilter,    LIBXSMM_DNN_GRADIENT_FILTER ) );
356     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_bias,         LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) );
357     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_delbias,      LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) );
358     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_tensor( libxsmm_handle, libxsmm_relumask,     LIBXSMM_DNN_RELU_MASK ) );
359 
360     /* let's allocate and bind scratch */
361     scratch_size = libxsmm_dnn_fullyconnected_get_scratch_size( libxsmm_handle, &status );
362     CHKERR_LIBXSMM_DNN( status );
363     scratch = libxsmm_aligned_scratch( scratch_size, 2097152 );
364     CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_bind_scratch( libxsmm_handle, scratch ) );
365     /* set scratch to bogus to make sure that libxsmm takes care of zeroing internally */
366     init_buf( (float*)scratch, scratch_size/4, 0, 0 );
367 
368     if ((type == 'A' || type == 'F') && LIBXSMM_NEQ(0, check)) {
369       printf("##########################################\n");
370       printf("#   Correctness - FWD (custom-Storage)   #\n");
371       printf("##########################################\n");
372 
373 #if defined(_OPENMP)
374 #     pragma omp parallel
375 #endif
376       {
377 #if defined(_OPENMP)
378         const int tid = omp_get_thread_num();
379 #else
380         const int tid = 0;
381 #endif
382         CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid ) );
383       }
384 
385       /* copy out data */
386       matrix_copy_NCNC_to_NC_bf16( output_libxsmm, naive_libxsmm_output_bf16, 1, nImg, nOFm, bn, bk );
387       libxsmm_convert_bf16_f32( naive_libxsmm_output_bf16, naive_libxsmm_output_f32, nImg*nOFm );
388 
389       /* compare */
390       libxsmm_matdiff(&norms_fwd, LIBXSMM_DATATYPE_F32, nImg*nOFm, 1, naive_output, naive_libxsmm_output_f32, 0, 0);
391       printf("L1 reference  : %.25g\n", norms_fwd.l1_ref);
392       printf("L1 test       : %.25g\n", norms_fwd.l1_tst);
393       printf("L2 abs.error  : %.24f\n", norms_fwd.l2_abs);
394       printf("L2 rel.error  : %.24f\n", norms_fwd.l2_rel);
395       printf("Linf abs.error: %.24f\n", norms_fwd.linf_abs);
396       printf("Linf rel.error: %.24f\n", norms_fwd.linf_rel);
397       printf("Check-norm    : %.24f\n", norms_fwd.normf_rel);
398       libxsmm_matdiff_reduce(&diff, &norms_fwd);
399     }
400     if ( (type == 'A' || type == 'B') && LIBXSMM_NEQ(0, check) ) {
401       printf("##########################################\n");
402       printf("#   Correctness - BWD (custom-Storage)   #\n");
403       printf("##########################################\n");
404 
405 #if defined(_OPENMP)
406 #     pragma omp parallel
407 #endif
408       {
409 #if defined(_OPENMP)
410         const int tid = omp_get_thread_num();
411 #else
412         const int tid = 0;
413 #endif
414         CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid ) );
415       }
416 
417       /* copy out data */
418       matrix_copy_NCNC_to_NC_bf16( delinput_libxsmm, naive_libxsmm_delinput_bf16, 1, nImg, nIFm, bn, bc );
419       libxsmm_convert_bf16_f32( naive_libxsmm_delinput_bf16, naive_libxsmm_delinput_f32, nImg*nIFm );
420 
421       /* compare */
422       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, nImg*nIFm, 1, naive_delinput, naive_libxsmm_delinput_f32, 0, 0);
423       printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
424       printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
425       printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
426       printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
427       printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
428       printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
429       printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
430       libxsmm_matdiff_reduce(&diff, &norms_bwd);
431 
432       if ( (fuse_type == 1) || (fuse_type == 4) || (fuse_type == 5) ) {
433       /* copy out data */
434       matrix_copy_NCNC_to_NC_bf16( delbias_libxsmm, naive_libxsmm_delbias_bf16, 1, 1, nOFm, 1, nOFm );
435       libxsmm_convert_bf16_f32( naive_libxsmm_delbias_bf16, naive_libxsmm_delbias_f32, nOFm );
436       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, nOFm, 1, naive_delbias, naive_libxsmm_delbias_f32, 0, 0);
437         printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
438         printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
439         printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
440         printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
441         printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
442         printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
443         printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
444         libxsmm_matdiff_reduce(&diff, &norms_bwd);
445       }
446     }
447 
448     if ( (type == 'A' || type == 'U') && LIBXSMM_NEQ(0, check) ) {
449       printf("##########################################\n");
450       printf("#   Correctness - UPD (custom-Storage)   #\n");
451       printf("##########################################\n");
452 
453 #if defined(_OPENMP)
454 #     pragma omp parallel
455 #endif
456       {
457 #if defined(_OPENMP)
458         const int tid = omp_get_thread_num();
459 #else
460         const int tid = 0;
461 #endif
462         CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid ) );
463       }
464 
465       /* copy out data */
466       matrix_copy_KCCK_to_KC_bf16( delfilter_libxsmm, naive_libxsmm_delfilter_bf16, nIFm, nOFm, bc, bk );
467       libxsmm_convert_bf16_f32( naive_libxsmm_delfilter_bf16, naive_libxsmm_delfilter_f32, nIFm*nOFm );
468 
469       /* compare */
470       libxsmm_matdiff(&norms_upd, LIBXSMM_DATATYPE_F32, nIFm*nOFm, 1, naive_delfilter, naive_libxsmm_delfilter_f32, 0, 0);
471       printf("L1 reference  : %.25g\n", norms_upd.l1_ref);
472       printf("L1 test       : %.25g\n", norms_upd.l1_tst);
473       printf("L2 abs.error  : %.24f\n", norms_upd.l2_abs);
474       printf("L2 rel.error  : %.24f\n", norms_upd.l2_rel);
475       printf("Linf abs.error: %.24f\n", norms_upd.linf_abs);
476       printf("Linf rel.error: %.24f\n", norms_upd.linf_rel);
477       printf("Check-norm    : %.24f\n", norms_upd.normf_rel);
478       libxsmm_matdiff_reduce(&diff, &norms_upd);
479     }
480 
481     if ( (type == 'A' || type == 'M') && LIBXSMM_NEQ(0, check) ) {
482       printf("##########################################\n");
483       printf("# Correctness - BWDUPD (custom-Storage)  #\n");
484       printf("##########################################\n");
485 
486 #if defined(_OPENMP)
487 #     pragma omp parallel
488 #endif
489       {
490 #if defined(_OPENMP)
491         const int tid = omp_get_thread_num();
492 #else
493         const int tid = 0;
494 #endif
495         CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid ) );
496       }
497 
498       /* copy out data */
499       matrix_copy_NCNC_to_NC_bf16( delinput_libxsmm, naive_libxsmm_delinput_bf16, 1, nImg, nIFm, bn, bc );
500       libxsmm_convert_bf16_f32( naive_libxsmm_delinput_bf16, naive_libxsmm_delinput_f32, nImg*nIFm );
501 
502       /* compare */
503       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, nImg*nIFm, 1, naive_delinput, naive_libxsmm_delinput_f32, 0, 0);
504       printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
505       printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
506       printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
507       printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
508       printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
509       printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
510       printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
511       libxsmm_matdiff_reduce(&diff, &norms_bwd);
512 
513       if ( (fuse_type == 1) || (fuse_type == 4) || (fuse_type == 5) ) {
514       /* copy out data */
515       matrix_copy_NCNC_to_NC_bf16( delbias_libxsmm, naive_libxsmm_delbias_bf16, 1, 1, nOFm, 1, nOFm );
516       libxsmm_convert_bf16_f32( naive_libxsmm_delbias_bf16, naive_libxsmm_delbias_f32, nOFm );
517       libxsmm_matdiff(&norms_bwd, LIBXSMM_DATATYPE_F32, nOFm, 1, naive_delbias, naive_libxsmm_delbias_f32, 0, 0);
518         printf("L1 reference  : %.25g\n", norms_bwd.l1_ref);
519         printf("L1 test       : %.25g\n", norms_bwd.l1_tst);
520         printf("L2 abs.error  : %.24f\n", norms_bwd.l2_abs);
521         printf("L2 rel.error  : %.24f\n", norms_bwd.l2_rel);
522         printf("Linf abs.error: %.24f\n", norms_bwd.linf_abs);
523         printf("Linf rel.error: %.24f\n", norms_bwd.linf_rel);
524         printf("Check-norm    : %.24f\n", norms_bwd.normf_rel);
525         libxsmm_matdiff_reduce(&diff, &norms_bwd);
526       }
527 
528       /* copy out data */
529       matrix_copy_KCCK_to_KC_bf16( delfilter_libxsmm, naive_libxsmm_delfilter_bf16, nIFm, nOFm, bc, bk );
530       libxsmm_convert_bf16_f32( naive_libxsmm_delfilter_bf16, naive_libxsmm_delfilter_f32, nIFm*nOFm );
531 
532       /* compare */
533       libxsmm_matdiff(&norms_upd, LIBXSMM_DATATYPE_F32, nIFm*nOFm, 1, naive_delfilter, naive_libxsmm_delfilter_f32, 0, 0);
534       printf("L1 reference  : %.25g\n", norms_upd.l1_ref);
535       printf("L1 test       : %.25g\n", norms_upd.l1_tst);
536       printf("L2 abs.error  : %.24f\n", norms_upd.l2_abs);
537       printf("L2 rel.error  : %.24f\n", norms_upd.l2_rel);
538       printf("Linf abs.error: %.24f\n", norms_upd.linf_abs);
539       printf("Linf rel.error: %.24f\n", norms_upd.linf_rel);
540       printf("Check-norm    : %.24f\n", norms_upd.normf_rel);
541       libxsmm_matdiff_reduce(&diff, &norms_upd);
542      }
543 
544     if (type == 'A' || type == 'F') {
545       printf("##########################################\n");
546       printf("#   Performance - FWD (custom-Storage)   #\n");
547       printf("##########################################\n");
548       l_start = libxsmm_timer_tick();
549 #if defined(_OPENMP)
550 #     pragma omp parallel private(i)
551 #endif
552       {
553 #if defined(_OPENMP)
554         const int tid = omp_get_thread_num();
555 #else
556         const int tid = 0;
557 #endif
558         for (i = 0; i < iters; ++i) {
559           libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_FWD, 0, tid );
560         }
561       }
562       l_end = libxsmm_timer_tick();
563       l_total = libxsmm_timer_duration(l_start, l_end);
564 
565       gflop = (2.0*(double)nImg*(double)nIFm*(double)nOFm*(double)iters) / (1000*1000*1000);
566 
567       printf("GFLOP  = %.5g\n", gflop/(double)iters);
568       printf("fp time = %.5g\n", ((double)(l_total/iters)));
569       printf("GFLOPS  = %.5g\n", gflop/l_total);
570 
571       char tune_string_fwd[1000];
572       sprintf(tune_string_fwd,"threads=%d_2D=%d_rows=%d_cols=%d_BN=%d_BK=%d_BC=%d_BFACCUM=%d",nThreads, fwd_2d_blocking, fwd_row_teams, fwd_column_teams, bn, bk, bc, fwd_bf);
573 
574       printf("PERFDUMP,%s,FP,%s,%i,%i,%i,%i,%.5g,%.5g,%f,%f,%f,%f,%f,%f,%f\n",tune_string_fwd, LIBXSMM_VERSION, nThreads, nImg, nIFm,
575           nOFm, ((double)(l_total/iters)), gflop/l_total, norms_fwd.l1_ref, norms_fwd.l1_tst,
576           norms_fwd.l2_abs, norms_fwd.l2_rel, norms_fwd.linf_abs, norms_fwd.linf_rel, norms_fwd.normf_rel);
577     }
578     if (type == 'A' || type == 'B') {
579       printf("##########################################\n");
580       printf("#   Performance - BWD (custom-Storage)   #\n");
581       printf("##########################################\n");
582       l_start = libxsmm_timer_tick();
583 #if defined(_OPENMP)
584 #     pragma omp parallel private(i)
585 #endif
586       {
587 #if defined(_OPENMP)
588         const int tid = omp_get_thread_num();
589 #else
590         const int tid = 0;
591 #endif
592         for (i = 0; i < iters; ++i) {
593           libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWD, 0, tid );
594         }
595       }
596       l_end = libxsmm_timer_tick();
597       l_total = libxsmm_timer_duration(l_start, l_end);
598 
599       gflop = (2.0*(double)nImg*(double)nIFm*(double)nOFm*(double)iters) / (1000*1000*1000);
600 
601       printf("GFLOP  = %.5g\n", gflop/(double)iters);
602       printf("fp time = %.5g\n", ((double)(l_total/iters)));
603       printf("GFLOPS  = %.5g\n", gflop/l_total);
604 
605       char tune_string_bwd[1000];
606       sprintf(tune_string_bwd,"threads=%d_2D=%d_rows=%d_cols=%d_BN=%d_BK=%d_BC=%d_BFACCUM=%d",nThreads, bwd_2d_blocking, bwd_row_teams, bwd_column_teams, bn, bk, bc, bwd_bf);
607       printf("PERFDUMP,%s,BP,%s,%i,%i,%i,%i,%.5g,%.5g,%f,%f,%f,%f,%f,%f,%f\n", tune_string_bwd , LIBXSMM_VERSION, nThreads, nImg, nIFm,
608           nOFm, ((double)(l_total/iters)), gflop/l_total, norms_bwd.l1_ref, norms_bwd.l1_tst,
609           norms_bwd.l2_abs, norms_bwd.l2_rel, norms_bwd.linf_abs, norms_bwd.linf_rel, norms_bwd.normf_rel);
610     }
611     if (type == 'A' || type == 'U') {
612       printf("##########################################\n");
613       printf("#   Performance - UPD (custom-Storage)   #\n");
614       printf("##########################################\n");
615       l_start = libxsmm_timer_tick();
616 #if defined(_OPENMP)
617 #     pragma omp parallel private(i)
618 #endif
619       {
620 #if defined(_OPENMP)
621         const int tid = omp_get_thread_num();
622 #else
623         const int tid = 0;
624 #endif
625         for (i = 0; i < iters; ++i) {
626           libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_UPD, 0, tid );
627         }
628       }
629       l_end = libxsmm_timer_tick();
630       l_total = libxsmm_timer_duration(l_start, l_end);
631 
632       gflop = (2.0*(double)nImg*(double)nIFm*(double)nOFm*(double)iters) / (1000*1000*1000);
633 
634       printf("GFLOP  = %.5g\n", gflop/(double)iters);
635       printf("fp time = %.5g\n", ((double)(l_total/iters)));
636       printf("GFLOPS  = %.5g\n", gflop/l_total);
637 
638       char tune_string_upd[1000];
639       sprintf(tune_string_upd,"threads=%d_2D=%d_rows=%d_cols=%d_BN=%d_BK=%d_BC=%d_BFACCUM=%d_IFMSUBTASK=%d_OFMSUBTASK=%d",nThreads, upd_2d_blocking, upd_row_teams, upd_column_teams, bn, bk, bc, upd_bf, ifm_subtasks, ofm_subtasks);
640       printf("PERFDUMP,%s,UP,%s,%i,%i,%i,%i,%.5g,%.5g,%f,%f,%f,%f,%f,%f,%f\n", tune_string_upd , LIBXSMM_VERSION, nThreads, nImg, nIFm,
641           nOFm, ((double)(l_total/iters)), gflop/l_total, norms_upd.l1_ref, norms_upd.l1_tst,
642           norms_upd.l2_abs, norms_upd.l2_rel, norms_upd.linf_abs, norms_upd.linf_rel, norms_upd.normf_rel);
643     }
644 
645     if (type == 'A' || type == 'M') {
646       printf("##########################################\n");
647       printf("# Performance - BWDUPD (custom-Storage)  #\n");
648       printf("##########################################\n");
649       l_start = libxsmm_timer_tick();
650 #if defined(_OPENMP)
651 #     pragma omp parallel private(i)
652 #endif
653       {
654 #if defined(_OPENMP)
655         const int tid = omp_get_thread_num();
656 #else
657         const int tid = 0;
658 #endif
659         for (i = 0; i < iters; ++i) {
660           libxsmm_dnn_fullyconnected_execute_st( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_BWDUPD, 0, tid );
661         }
662       }
663       l_end = libxsmm_timer_tick();
664       l_total = libxsmm_timer_duration(l_start, l_end);
665 
666       gflop = (4.0*(double)nImg*(double)nIFm*(double)nOFm*(double)iters) / (1000*1000*1000);
667 
668       printf("GFLOP  = %.5g\n", gflop/(double)iters);
669       printf("fp time = %.5g\n", ((double)(l_total/iters)));
670       printf("GFLOPS  = %.5g\n", gflop/l_total);
671 
672       printf("PERFDUMP,UP,%s,%i,%i,%i,%i,%.5g,%.5g,%f,%f,%f,%f,%f,%f,%f\n", LIBXSMM_VERSION, nThreads, nImg, nIFm,
673           nOFm, ((double)(l_total/iters)), gflop/l_total, norms_upd.l1_ref, norms_upd.l1_tst,
674           norms_upd.l2_abs, norms_upd.l2_rel, norms_upd.linf_abs, norms_upd.linf_rel, norms_upd.normf_rel);
675     }
676 
    /* clean-up: detach scratch and tensors from the handle, then destroy them.
     * Every call goes through CHKERR_LIBXSMM_DNN, which prints the error and
     * records it in global_status instead of aborting, so cleanup continues. */
    /* release the scratch binding before the buffer itself is freed
     * (presumably so the handle no longer references it) */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_scratch( libxsmm_handle ) );
    libxsmm_free(scratch);
    /* unbind every tensor kind that was bound to the handle */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_REGULAR_CHANNEL_BIAS ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_GRADIENT_CHANNEL_BIAS ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_fullyconnected_release_tensor( libxsmm_handle, LIBXSMM_DNN_RELU_MASK ) );


    /* destroy the tensor descriptors; the underlying *_libxsmm data buffers
     * are freed separately below */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_input ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_delinput ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_output ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_deloutput ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_filter ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_delfilter ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_bias ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_delbias ) );
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_tensor( libxsmm_relumask ) );

    /* finally destroy the fully-connected handle itself */
    CHKERR_LIBXSMM_DNN( libxsmm_dnn_destroy_fullyconnected( libxsmm_handle ) );
702   }
703 
704   /* deallocate data */
705   libxsmm_free(naive_input);
706   libxsmm_free(naive_output);
707   libxsmm_free(naive_delinput);
708   libxsmm_free(naive_deloutput);
709   libxsmm_free(naive_filter);
710   libxsmm_free(naive_delfilter);
711   libxsmm_free(naive_input_bf16);
712   libxsmm_free(naive_delinput_bf16);
713   libxsmm_free(naive_output_bf16);
714   libxsmm_free(naive_deloutput_bf16);
715   libxsmm_free(naive_filter_bf16);
716   libxsmm_free(naive_delfilter_bf16);
717   libxsmm_free(naive_libxsmm_output_bf16);
718   libxsmm_free(naive_libxsmm_delinput_bf16);
719   libxsmm_free(naive_libxsmm_delfilter_bf16);
720   libxsmm_free(naive_libxsmm_output_f32);
721   libxsmm_free(naive_libxsmm_delinput_f32);
722   libxsmm_free(naive_libxsmm_delfilter_f32);
723   libxsmm_free(input_libxsmm);
724   libxsmm_free(output_libxsmm);
725   libxsmm_free(delinput_libxsmm);
726   libxsmm_free(deloutput_libxsmm);
727   libxsmm_free(filter_libxsmm);
728   libxsmm_free(delfilter_libxsmm);
729   libxsmm_free(naive_bias);
730   libxsmm_free(naive_delbias);
731   libxsmm_free(naive_bias_bf16);
732   libxsmm_free(naive_delbias_bf16);
733   libxsmm_free(naive_libxsmm_delbias_bf16);
734   libxsmm_free(naive_libxsmm_delbias_f32);
735   libxsmm_free(relumask_libxsmm);
736   libxsmm_free(bias_libxsmm);
737   libxsmm_free(delbias_libxsmm);
738 
739   { const char *const env_check_scale = getenv("CHECK_SCALE");
740     const double check_scale = LIBXSMM_ABS(0 == env_check_scale ? 1.0 : atof(env_check_scale));
741     if (LIBXSMM_NEQ(0, check) && (check < 100.0 * check_scale * diff.normf_rel) && (global_status == LIBXSMM_DNN_SUCCESS)) {
742       fprintf(stderr, "FAILED with an error of %f%%!\n", 100.0 * diff.normf_rel);
743       exit(EXIT_FAILURE);
744     }
745   }
746 
747   /* some empty lines at the end */
748   printf("\n\n\n");
749 
750   return global_status;
751 }
752 
753