1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 #include <libxsmm_intrinsics_x86.h>
13 
/* Build-time configuration. Every macro below is only a default and can be
 * overridden by defining it before this point (e.g., on the command line). */
#if !defined(ITYPE)
# define ITYPE double /* element type of the input matrices a and b */
#endif
#if !defined(OTYPE)
# define OTYPE ITYPE /* element type of the result buffers, alpha, and beta */
#endif
#if !defined(CHECK_FPE)
# define CHECK_FPE /* enables the floating-point exception check in main */
#endif
#if !defined(GEMM_GOLD)
# define GEMM_GOLD LIBXSMM_GEMM_SYMBOL /* routine producing the reference result */
#endif
#if !defined(GEMM)
# define GEMM LIBXSMM_XGEMM_SYMBOL /* routine used when SMM_NO_BYPASS yields zero */
#endif
#if !defined(GEMM2)
# define GEMM2 LIBXSMM_YGEMM_SYMBOL /* routine filling the second result buffer */
#endif
#if !defined(SMM)
# define SMM LIBXSMM_XGEMM_SYMBOL /* routine used when SMM_NO_BYPASS yields non-zero */
#endif
/* The guard must test the macro that is defined here: guarding on a different
 * name would redefine a user-supplied SMM_NO_BYPASS, or leave SMM_NO_BYPASS
 * undefined (although main uses it) if only the other name were defined. */
#if !defined(SMM_NO_BYPASS)
# define SMM_NO_BYPASS(FLAGS, ALPHA, BETA) LIBXSMM_GEMM_NO_BYPASS(FLAGS, ALPHA, BETA)
#endif
38 
39 
40 #if (LIBXSMM_EQUAL(ITYPE, float) || LIBXSMM_EQUAL(ITYPE, double)) \
41   && !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
LIBXSMM_BLAS_SYMBOL_DECL(ITYPE,gemm)42 LIBXSMM_BLAS_SYMBOL_DECL(ITYPE, gemm)
43 #endif
44 
45 
46 int main(void)
47 {
48   /* test#:                 1  2  3  4  5  6  7  8  9 10 11 12    13   14     15  16  17  18  19     20   21   22   23   24   25   26   27   28   29  30  31  32  33    34    35  36 37 */
49   /* index:                 0  1  2  3  4  5  6  7  8  9 10 11    12   13     14  15  16  17  18     19   20   21   22   23   24   25   26   27   28  29  30  31  32    33    34  35 36 */
50   libxsmm_blasint m[]   = { 0, 1, 0, 0, 1, 1, 2, 3, 3, 1, 4, 8,   64,  64,    16, 80, 80, 80, 80,    16, 260, 260, 260, 260, 350, 350, 350, 350, 350,  5, 10, 12, 20,   32,    9, 13, 5 };
51   libxsmm_blasint n[]   = { 0, 0, 1, 0, 1, 2, 2, 3, 1, 3, 1, 1,    8, 239, 13824,  1,  3,  5,  7, 65792,   1,   3,   5,   7,  16,   1,  25,   4,   9, 13,  1, 10,  6,   33,    9, 13, 5 };
52   libxsmm_blasint k[]   = { 0, 0, 0, 1, 1, 2, 2, 3, 2, 2, 4, 0,   64,  64,    16,  1,  3,  6, 10,    16,   1,   3,   6,  10,  20,   1,  35,   4,  10, 70,  1, 12,  6,  192, 1742, 13, 5 };
53   libxsmm_blasint lda[] = { 1, 1, 1, 1, 1, 1, 2, 3, 3, 1, 4, 8,   64,  64,    16, 80, 80, 80, 80,    16, 260, 260, 260, 260, 350, 350, 350, 350, 350,  5, 22, 22, 22,   32,    9, 13, 5 };
54   libxsmm_blasint ldb[] = { 1, 1, 1, 1, 1, 2, 2, 3, 2, 2, 4, 8, 9216, 240,    16,  1,  3,  5,  5,    16,   1,   3,   5,   7,  35,  35,  35,  35,  35, 70,  1, 20,  8, 2048, 1742, 13, 5 };
55   libxsmm_blasint ldc[] = { 1, 1, 1, 1, 1, 1, 2, 3, 3, 1, 4, 8, 4096, 240,    16, 80, 80, 80, 80,    16, 260, 260, 260, 260, 350, 350, 350, 350, 350,  5, 22, 12, 20, 2048,    9, 13, 5 };
56   OTYPE alpha[]         = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,    1,   1,     1,  1,  1,  1,  1,     1,   1,   1,   1,   1,   1,   1,   1,   1,   1,  1,  1,  1,  1,    1,    1,  1, 1 };
57   OTYPE beta[]          = { 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0,    0,   1,     0,  0,  0,  0,  0,     1,   0,   0,   0,   0,   0,   0,   1,   0,   0,  1,  0,  1,  0,    1,    0,  1, 1 };
58 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
59   char transa[] = "NNNTT";
60 #else
61   char transa[] = "NN";
62 #endif
63   char transb[] = "NNTNT";
64   const int begin = 0, end = sizeof(m) / sizeof(*m), i0 = 0, i1 = sizeof(transa) - 1;
65   libxsmm_blasint max_size_a = 0, max_size_b = 0, max_size_c = 0, block = 1;
66 #if defined(_DEBUG)
67   libxsmm_matdiff_info diff;
68 #endif
69   ITYPE *a = NULL, *b = NULL;
70   OTYPE *c = NULL;
71 #if defined(GEMM)
72   OTYPE *d = NULL;
73 #endif
74 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
75   OTYPE *gold = NULL;
76 #endif
77   int result = EXIT_SUCCESS, test, i;
78 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
79   const unsigned int fpemask = _MM_GET_EXCEPTION_MASK(); /* backup FPE mask */
80   const unsigned int fpcheck = _MM_MASK_INVALID | _MM_MASK_OVERFLOW;
81   unsigned int fpstate = 0;
82   _MM_SET_EXCEPTION_MASK(fpemask & ~fpcheck);
83 #endif
84   LIBXSMM_BLAS_INIT
85   for (test = begin; test < end; ++test) {
86     m[test] = LIBXSMM_UP(m[test], block);
87     n[test] = LIBXSMM_UP(n[test], block);
88     k[test] = LIBXSMM_UP(k[test], block);
89     lda[test] = LIBXSMM_MAX(lda[test], m[test]);
90     ldb[test] = LIBXSMM_MAX(ldb[test], k[test]);
91     ldc[test] = LIBXSMM_MAX(ldc[test], m[test]);
92   }
93   for (test = begin; test < end; ++test) {
94     const libxsmm_blasint size_a = lda[test] * k[test], size_b = ldb[test] * n[test], size_c = ldc[test] * n[test];
95     LIBXSMM_ASSERT(m[test] <= lda[test] && k[test] <= ldb[test] && m[test] <= ldc[test]);
96     max_size_a = LIBXSMM_MAX(max_size_a, size_a);
97     max_size_b = LIBXSMM_MAX(max_size_b, size_b);
98     max_size_c = LIBXSMM_MAX(max_size_c, size_c);
99   }
100   a = (ITYPE*)libxsmm_malloc((size_t)(max_size_a * sizeof(ITYPE)));
101   b = (ITYPE*)libxsmm_malloc((size_t)(max_size_b * sizeof(ITYPE)));
102   c = (OTYPE*)libxsmm_malloc((size_t)(max_size_c * sizeof(OTYPE)));
103 #if defined(GEMM)
104   d = (OTYPE*)libxsmm_malloc((size_t)(max_size_c * sizeof(OTYPE)));
105   LIBXSMM_ASSERT(NULL != d);
106 #endif
107 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
108   gold = (OTYPE*)libxsmm_malloc((size_t)(max_size_c * sizeof(OTYPE)));
109   LIBXSMM_ASSERT(NULL != gold);
110 #endif
111   LIBXSMM_ASSERT(NULL != a && NULL != b && NULL != c);
112   LIBXSMM_MATINIT(ITYPE, 42, a, max_size_a, 1, max_size_a, 1.0);
113   LIBXSMM_MATINIT(ITYPE, 24, b, max_size_b, 1, max_size_b, 1.0);
114 #if defined(_DEBUG)
115   libxsmm_matdiff_clear(&diff);
116 #endif
117   for (test = begin; test < end && EXIT_SUCCESS == result; ++test) {
118     for (i = i0; i < i1 && EXIT_SUCCESS == result; ++i) {
119       libxsmm_blasint mi = m[test], ni = n[test], ki = k[test];
120       const int flags = LIBXSMM_GEMM_FLAGS(transa[i], transb[i]);
121       const int smm = SMM_NO_BYPASS(flags, alpha[test], beta[test]);
122 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
123       _MM_SET_EXCEPTION_STATE(0);
124 #endif
125       if ('N' != transa[i] && 'N' == transb[i]) { /* TN */
126         mi = ki = LIBXSMM_MIN(mi, ki);
127       }
128       else if ('N' == transa[i] && 'N' != transb[i]) { /* NT */
129         ki = ni = LIBXSMM_MIN(ki, ni);
130       }
131       else if ('N' != transa[i] && 'N' != transb[i]) { /* TT */
132         const libxsmm_blasint ti = LIBXSMM_MIN(mi, ni);
133         mi = ni = ki = LIBXSMM_MIN(ti, ki);
134       }
135       if (LIBXSMM_FEQ(0, beta[test])) {
136 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
137         memset(gold, -1, (size_t)(sizeof(OTYPE) * max_size_c));
138 #endif
139         memset(c, -1, (size_t)(sizeof(OTYPE) * max_size_c));
140 #if defined(GEMM)
141         memset(d, -1, (size_t)(sizeof(OTYPE) * max_size_c));
142 #endif
143       }
144       else {
145 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
146         memset(gold, 0, (size_t)(sizeof(OTYPE) * max_size_c));
147 #endif
148         memset(c, 0, (size_t)(sizeof(OTYPE) * max_size_c));
149 #if defined(GEMM)
150         memset(d, 0, (size_t)(sizeof(OTYPE) * max_size_c));
151 #endif
152       }
153       if (0 != smm) {
154         SMM(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
155           alpha + test, a, lda + test, b, ldb + test, beta + test, c, ldc + test);
156       }
157 #if defined(GEMM)
158       else {
159         GEMM(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
160           alpha + test, a, lda + test, b, ldb + test, beta + test, c, ldc + test);
161       }
162 # if defined(GEMM2)
163       GEMM2(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
164         alpha + test, a, lda + test, b, ldb + test, beta + test, d, ldc + test);
165 # else
166       GEMM(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
167         alpha + test, a, lda + test, b, ldb + test, beta + test, d, ldc + test);
168 # endif
169 #endif
170 #if (0 != LIBXSMM_JIT)
171       if (0 != smm) { /* dispatch kernel and check that it is available */
172         const LIBXSMM_MMFUNCTION_TYPE(ITYPE) kernel = LIBXSMM_MMDISPATCH_SYMBOL(ITYPE)(mi, ni, ki,
173           lda + test, ldb + test, ldc + test, alpha + test, beta + test, &flags, NULL/*prefetch*/);
174         if (NULL == kernel) {
175 # if defined(_DEBUG)
176           fprintf(stderr, "\nERROR: kernel %i.%i not generated!\n\t", test + 1, i + 1);
177           libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION(ITYPE), transa + i, transb + i, &mi, &ni, &ki,
178             alpha + test, NULL/*a*/, lda + test, NULL/*b*/, ldb + test, beta + test, NULL/*c*/, ldc + test);
179           fprintf(stderr, "\n");
180 # endif
181           result = EXIT_FAILURE;
182           break;
183         }
184       }
185 #endif
186 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
187       fpstate = _MM_GET_EXCEPTION_STATE() & fpcheck;
188       result = (0 == fpstate ? EXIT_SUCCESS : EXIT_FAILURE);
189       if (EXIT_SUCCESS != result) {
190 # if defined(_DEBUG)
191         fprintf(stderr, "FPE(%i.%i): state=0x%08x -> invalid=%s overflow=%s\n", test + 1, i + 1, fpstate,
192           0 != (_MM_MASK_INVALID  & fpstate) ? "true" : "false",
193           0 != (_MM_MASK_OVERFLOW & fpstate) ? "true" : "false");
194 # endif
195       }
196 # if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
197       else
198 # endif
199 #endif
200 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
201 # if !defined(GEMM)
202       if (0 != smm)
203 # endif
204       {
205 # if defined(GEMM_GOLD)
206         libxsmm_matdiff_info diff_test;
207         GEMM_GOLD(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
208           alpha + test, a, lda + test, b, ldb + test, beta + test, gold, ldc + test);
209 
210         result = libxsmm_matdiff(&diff_test, LIBXSMM_DATATYPE(OTYPE), mi, ni, gold, c, ldc + test, ldc + test);
211         if (EXIT_SUCCESS == result) {
212 #   if defined(_DEBUG)
213           libxsmm_matdiff_reduce(&diff, &diff_test);
214 #   endif
215           if (1.0 < (1000.0 * diff_test.normf_rel)) {
216 #   if defined(_DEBUG)
217             if (0 != smm) {
218               fprintf(stderr, "\nERROR: SMM test %i.%i failed!\n\t", test + 1, i + 1);
219             }
220             else {
221               fprintf(stderr, "\nERROR: test %i.%i failed!\n\t", test + 1, i + 1);
222             }
223             libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION(ITYPE), transa + i, transb + i, &mi, &ni, &ki,
224               alpha + test, NULL/*a*/, lda + test, NULL/*b*/, ldb + test, beta + test, NULL/*c*/, ldc + test);
225             fprintf(stderr, "\n");
226 #   endif
227             result = EXIT_FAILURE;
228           }
229 #   if defined(GEMM)
230           else {
231             result = libxsmm_matdiff(&diff_test, LIBXSMM_DATATYPE(OTYPE), mi, ni, gold, d, ldc + test, ldc + test);
232             if (EXIT_SUCCESS == result) {
233 #     if defined(_DEBUG)
234               libxsmm_matdiff_reduce(&diff, &diff_test);
235 #     endif
236               if (1.0 < (1000.0 * diff_test.normf_rel)) {
237 #     if defined(_DEBUG)
238                 fprintf(stderr, "\nERROR: test %i.%i failed!\n\t", test + 1, i + 1);
239                 libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION(ITYPE), transa + i, transb + i, &mi, &ni, &ki,
240                   alpha + test, NULL/*a*/, lda + test, NULL/*b*/, ldb + test, beta + test, NULL/*c*/, ldc + test);
241                 fprintf(stderr, "\n");
242 #     endif
243                 result = EXIT_FAILURE;
244               }
245             }
246           }
247 #   endif
248         }
249 # endif
250       }
251 # if defined(GEMM_GOLD)
252       /* avoid drift between Gold and test-results */
253       memcpy(c, gold, (size_t)(sizeof(OTYPE) * max_size_c));
254 #   if defined(GEMM)
255       memcpy(d, gold, (size_t)(sizeof(OTYPE) * max_size_c));
256 #   endif
257 # endif
258 #elif defined(_DEBUG)
259       fprintf(stderr, "Warning: skipped the test due to missing BLAS support!\n");
260 #endif
261     }
262   }
263 
264 #if defined(_DEBUG)
265   fprintf(stderr, "Diff: L2abs=%f Linf=%f\n", diff.l2_abs, diff.linf_abs);
266 #endif
267 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
268   _MM_SET_EXCEPTION_MASK(fpemask); /* restore FPE mask */
269   _MM_SET_EXCEPTION_STATE(0); /* clear FPE state */
270 #endif
271   libxsmm_free(a);
272   libxsmm_free(b);
273   libxsmm_free(c);
274 #if defined(GEMM)
275   libxsmm_free(d);
276 #endif
277 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
278   libxsmm_free(gold);
279 #endif
280   return result;
281 }
282 
283