/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#include <libxsmm.h>
#include <libxsmm_intrinsics_x86.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Compile-time configuration of the test. Every macro below is only a
 * default and can be overridden from the compiler command line (-D...). */

/* Element type of the input matrices A and B. */
#if !defined(ITYPE)
# define ITYPE double
#endif
/* Element type of the output matrix C (and of alpha/beta); follows ITYPE by default. */
#if !defined(OTYPE)
# define OTYPE ITYPE
#endif
/* Enables the floating-point exception check (only effective when the
 * _MM_GET_EXCEPTION_MASK intrinsic macro is available). */
#if !defined(CHECK_FPE)
# define CHECK_FPE
#endif
/* Symbol of the reference GEMM which produces the "gold" result. */
#if !defined(GEMM_GOLD)
# define GEMM_GOLD LIBXSMM_GEMM_SYMBOL
#endif
/* Symbol of the GEMM under test; used when the SMM path is bypassed. */
#if !defined(GEMM)
# define GEMM LIBXSMM_XGEMM_SYMBOL
#endif
/* Symbol of a second GEMM under test; its result goes into a separate buffer. */
#if !defined(GEMM2)
# define GEMM2 LIBXSMM_YGEMM_SYMBOL
#endif
/* Symbol of the small-matrix multiplication used when SMM_NO_BYPASS is non-zero. */
#if !defined(SMM)
# define SMM LIBXSMM_XGEMM_SYMBOL
#endif
/* Predicate deciding whether a problem takes the SMM path. The guard tests
 * the macro that is actually defined here, so that a user-supplied
 * SMM_NO_BYPASS is honored and never redefined. */
#if !defined(SMM_NO_BYPASS)
# define SMM_NO_BYPASS(FLAGS, ALPHA, BETA) LIBXSMM_GEMM_NO_BYPASS(FLAGS, ALPHA, BETA)
#endif


40 #if (LIBXSMM_EQUAL(ITYPE, float) || LIBXSMM_EQUAL(ITYPE, double)) \
41 && !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
LIBXSMM_BLAS_SYMBOL_DECL(ITYPE,gemm)42 LIBXSMM_BLAS_SYMBOL_DECL(ITYPE, gemm)
43 #endif
44
45
46 int main(void)
47 {
48 /* test#: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 */
49 /* index: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 */
50 libxsmm_blasint m[] = { 0, 1, 0, 0, 1, 1, 2, 3, 3, 1, 4, 8, 64, 64, 16, 80, 80, 80, 80, 16, 260, 260, 260, 260, 350, 350, 350, 350, 350, 5, 10, 12, 20, 32, 9, 13, 5 };
51 libxsmm_blasint n[] = { 0, 0, 1, 0, 1, 2, 2, 3, 1, 3, 1, 1, 8, 239, 13824, 1, 3, 5, 7, 65792, 1, 3, 5, 7, 16, 1, 25, 4, 9, 13, 1, 10, 6, 33, 9, 13, 5 };
52 libxsmm_blasint k[] = { 0, 0, 0, 1, 1, 2, 2, 3, 2, 2, 4, 0, 64, 64, 16, 1, 3, 6, 10, 16, 1, 3, 6, 10, 20, 1, 35, 4, 10, 70, 1, 12, 6, 192, 1742, 13, 5 };
53 libxsmm_blasint lda[] = { 1, 1, 1, 1, 1, 1, 2, 3, 3, 1, 4, 8, 64, 64, 16, 80, 80, 80, 80, 16, 260, 260, 260, 260, 350, 350, 350, 350, 350, 5, 22, 22, 22, 32, 9, 13, 5 };
54 libxsmm_blasint ldb[] = { 1, 1, 1, 1, 1, 2, 2, 3, 2, 2, 4, 8, 9216, 240, 16, 1, 3, 5, 5, 16, 1, 3, 5, 7, 35, 35, 35, 35, 35, 70, 1, 20, 8, 2048, 1742, 13, 5 };
55 libxsmm_blasint ldc[] = { 1, 1, 1, 1, 1, 1, 2, 3, 3, 1, 4, 8, 4096, 240, 16, 80, 80, 80, 80, 16, 260, 260, 260, 260, 350, 350, 350, 350, 350, 5, 22, 12, 20, 2048, 9, 13, 5 };
56 OTYPE alpha[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 };
57 OTYPE beta[] = { 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1 };
58 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
59 char transa[] = "NNNTT";
60 #else
61 char transa[] = "NN";
62 #endif
63 char transb[] = "NNTNT";
64 const int begin = 0, end = sizeof(m) / sizeof(*m), i0 = 0, i1 = sizeof(transa) - 1;
65 libxsmm_blasint max_size_a = 0, max_size_b = 0, max_size_c = 0, block = 1;
66 #if defined(_DEBUG)
67 libxsmm_matdiff_info diff;
68 #endif
69 ITYPE *a = NULL, *b = NULL;
70 OTYPE *c = NULL;
71 #if defined(GEMM)
72 OTYPE *d = NULL;
73 #endif
74 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
75 OTYPE *gold = NULL;
76 #endif
77 int result = EXIT_SUCCESS, test, i;
78 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
79 const unsigned int fpemask = _MM_GET_EXCEPTION_MASK(); /* backup FPE mask */
80 const unsigned int fpcheck = _MM_MASK_INVALID | _MM_MASK_OVERFLOW;
81 unsigned int fpstate = 0;
82 _MM_SET_EXCEPTION_MASK(fpemask & ~fpcheck);
83 #endif
84 LIBXSMM_BLAS_INIT
85 for (test = begin; test < end; ++test) {
86 m[test] = LIBXSMM_UP(m[test], block);
87 n[test] = LIBXSMM_UP(n[test], block);
88 k[test] = LIBXSMM_UP(k[test], block);
89 lda[test] = LIBXSMM_MAX(lda[test], m[test]);
90 ldb[test] = LIBXSMM_MAX(ldb[test], k[test]);
91 ldc[test] = LIBXSMM_MAX(ldc[test], m[test]);
92 }
93 for (test = begin; test < end; ++test) {
94 const libxsmm_blasint size_a = lda[test] * k[test], size_b = ldb[test] * n[test], size_c = ldc[test] * n[test];
95 LIBXSMM_ASSERT(m[test] <= lda[test] && k[test] <= ldb[test] && m[test] <= ldc[test]);
96 max_size_a = LIBXSMM_MAX(max_size_a, size_a);
97 max_size_b = LIBXSMM_MAX(max_size_b, size_b);
98 max_size_c = LIBXSMM_MAX(max_size_c, size_c);
99 }
100 a = (ITYPE*)libxsmm_malloc((size_t)(max_size_a * sizeof(ITYPE)));
101 b = (ITYPE*)libxsmm_malloc((size_t)(max_size_b * sizeof(ITYPE)));
102 c = (OTYPE*)libxsmm_malloc((size_t)(max_size_c * sizeof(OTYPE)));
103 #if defined(GEMM)
104 d = (OTYPE*)libxsmm_malloc((size_t)(max_size_c * sizeof(OTYPE)));
105 LIBXSMM_ASSERT(NULL != d);
106 #endif
107 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
108 gold = (OTYPE*)libxsmm_malloc((size_t)(max_size_c * sizeof(OTYPE)));
109 LIBXSMM_ASSERT(NULL != gold);
110 #endif
111 LIBXSMM_ASSERT(NULL != a && NULL != b && NULL != c);
112 LIBXSMM_MATINIT(ITYPE, 42, a, max_size_a, 1, max_size_a, 1.0);
113 LIBXSMM_MATINIT(ITYPE, 24, b, max_size_b, 1, max_size_b, 1.0);
114 #if defined(_DEBUG)
115 libxsmm_matdiff_clear(&diff);
116 #endif
117 for (test = begin; test < end && EXIT_SUCCESS == result; ++test) {
118 for (i = i0; i < i1 && EXIT_SUCCESS == result; ++i) {
119 libxsmm_blasint mi = m[test], ni = n[test], ki = k[test];
120 const int flags = LIBXSMM_GEMM_FLAGS(transa[i], transb[i]);
121 const int smm = SMM_NO_BYPASS(flags, alpha[test], beta[test]);
122 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
123 _MM_SET_EXCEPTION_STATE(0);
124 #endif
125 if ('N' != transa[i] && 'N' == transb[i]) { /* TN */
126 mi = ki = LIBXSMM_MIN(mi, ki);
127 }
128 else if ('N' == transa[i] && 'N' != transb[i]) { /* NT */
129 ki = ni = LIBXSMM_MIN(ki, ni);
130 }
131 else if ('N' != transa[i] && 'N' != transb[i]) { /* TT */
132 const libxsmm_blasint ti = LIBXSMM_MIN(mi, ni);
133 mi = ni = ki = LIBXSMM_MIN(ti, ki);
134 }
135 if (LIBXSMM_FEQ(0, beta[test])) {
136 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
137 memset(gold, -1, (size_t)(sizeof(OTYPE) * max_size_c));
138 #endif
139 memset(c, -1, (size_t)(sizeof(OTYPE) * max_size_c));
140 #if defined(GEMM)
141 memset(d, -1, (size_t)(sizeof(OTYPE) * max_size_c));
142 #endif
143 }
144 else {
145 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
146 memset(gold, 0, (size_t)(sizeof(OTYPE) * max_size_c));
147 #endif
148 memset(c, 0, (size_t)(sizeof(OTYPE) * max_size_c));
149 #if defined(GEMM)
150 memset(d, 0, (size_t)(sizeof(OTYPE) * max_size_c));
151 #endif
152 }
153 if (0 != smm) {
154 SMM(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
155 alpha + test, a, lda + test, b, ldb + test, beta + test, c, ldc + test);
156 }
157 #if defined(GEMM)
158 else {
159 GEMM(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
160 alpha + test, a, lda + test, b, ldb + test, beta + test, c, ldc + test);
161 }
162 # if defined(GEMM2)
163 GEMM2(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
164 alpha + test, a, lda + test, b, ldb + test, beta + test, d, ldc + test);
165 # else
166 GEMM(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
167 alpha + test, a, lda + test, b, ldb + test, beta + test, d, ldc + test);
168 # endif
169 #endif
170 #if (0 != LIBXSMM_JIT)
171 if (0 != smm) { /* dispatch kernel and check that it is available */
172 const LIBXSMM_MMFUNCTION_TYPE(ITYPE) kernel = LIBXSMM_MMDISPATCH_SYMBOL(ITYPE)(mi, ni, ki,
173 lda + test, ldb + test, ldc + test, alpha + test, beta + test, &flags, NULL/*prefetch*/);
174 if (NULL == kernel) {
175 # if defined(_DEBUG)
176 fprintf(stderr, "\nERROR: kernel %i.%i not generated!\n\t", test + 1, i + 1);
177 libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION(ITYPE), transa + i, transb + i, &mi, &ni, &ki,
178 alpha + test, NULL/*a*/, lda + test, NULL/*b*/, ldb + test, beta + test, NULL/*c*/, ldc + test);
179 fprintf(stderr, "\n");
180 # endif
181 result = EXIT_FAILURE;
182 break;
183 }
184 }
185 #endif
186 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
187 fpstate = _MM_GET_EXCEPTION_STATE() & fpcheck;
188 result = (0 == fpstate ? EXIT_SUCCESS : EXIT_FAILURE);
189 if (EXIT_SUCCESS != result) {
190 # if defined(_DEBUG)
191 fprintf(stderr, "FPE(%i.%i): state=0x%08x -> invalid=%s overflow=%s\n", test + 1, i + 1, fpstate,
192 0 != (_MM_MASK_INVALID & fpstate) ? "true" : "false",
193 0 != (_MM_MASK_OVERFLOW & fpstate) ? "true" : "false");
194 # endif
195 }
196 # if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
197 else
198 # endif
199 #endif
200 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
201 # if !defined(GEMM)
202 if (0 != smm)
203 # endif
204 {
205 # if defined(GEMM_GOLD)
206 libxsmm_matdiff_info diff_test;
207 GEMM_GOLD(ITYPE)(transa + i, transb + i, &mi, &ni, &ki,
208 alpha + test, a, lda + test, b, ldb + test, beta + test, gold, ldc + test);
209
210 result = libxsmm_matdiff(&diff_test, LIBXSMM_DATATYPE(OTYPE), mi, ni, gold, c, ldc + test, ldc + test);
211 if (EXIT_SUCCESS == result) {
212 # if defined(_DEBUG)
213 libxsmm_matdiff_reduce(&diff, &diff_test);
214 # endif
215 if (1.0 < (1000.0 * diff_test.normf_rel)) {
216 # if defined(_DEBUG)
217 if (0 != smm) {
218 fprintf(stderr, "\nERROR: SMM test %i.%i failed!\n\t", test + 1, i + 1);
219 }
220 else {
221 fprintf(stderr, "\nERROR: test %i.%i failed!\n\t", test + 1, i + 1);
222 }
223 libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION(ITYPE), transa + i, transb + i, &mi, &ni, &ki,
224 alpha + test, NULL/*a*/, lda + test, NULL/*b*/, ldb + test, beta + test, NULL/*c*/, ldc + test);
225 fprintf(stderr, "\n");
226 # endif
227 result = EXIT_FAILURE;
228 }
229 # if defined(GEMM)
230 else {
231 result = libxsmm_matdiff(&diff_test, LIBXSMM_DATATYPE(OTYPE), mi, ni, gold, d, ldc + test, ldc + test);
232 if (EXIT_SUCCESS == result) {
233 # if defined(_DEBUG)
234 libxsmm_matdiff_reduce(&diff, &diff_test);
235 # endif
236 if (1.0 < (1000.0 * diff_test.normf_rel)) {
237 # if defined(_DEBUG)
238 fprintf(stderr, "\nERROR: test %i.%i failed!\n\t", test + 1, i + 1);
239 libxsmm_gemm_print(stderr, LIBXSMM_GEMM_PRECISION(ITYPE), transa + i, transb + i, &mi, &ni, &ki,
240 alpha + test, NULL/*a*/, lda + test, NULL/*b*/, ldb + test, beta + test, NULL/*c*/, ldc + test);
241 fprintf(stderr, "\n");
242 # endif
243 result = EXIT_FAILURE;
244 }
245 }
246 }
247 # endif
248 }
249 # endif
250 }
251 # if defined(GEMM_GOLD)
252 /* avoid drift between Gold and test-results */
253 memcpy(c, gold, (size_t)(sizeof(OTYPE) * max_size_c));
254 # if defined(GEMM)
255 memcpy(d, gold, (size_t)(sizeof(OTYPE) * max_size_c));
256 # endif
257 # endif
258 #elif defined(_DEBUG)
259 fprintf(stderr, "Warning: skipped the test due to missing BLAS support!\n");
260 #endif
261 }
262 }
263
264 #if defined(_DEBUG)
265 fprintf(stderr, "Diff: L2abs=%f Linf=%f\n", diff.l2_abs, diff.linf_abs);
266 #endif
267 #if defined(CHECK_FPE) && defined(_MM_GET_EXCEPTION_MASK)
268 _MM_SET_EXCEPTION_MASK(fpemask); /* restore FPE mask */
269 _MM_SET_EXCEPTION_STATE(0); /* clear FPE state */
270 #endif
271 libxsmm_free(a);
272 libxsmm_free(b);
273 libxsmm_free(c);
274 #if defined(GEMM)
275 libxsmm_free(d);
276 #endif
277 #if (!defined(__BLAS) || (0 != __BLAS)) && defined(GEMM_GOLD)
278 libxsmm_free(gold);
279 #endif
280 return result;
281 }
282
283