1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.), Dheevatsa Mudigere (Intel Corp.)
10 Alexander Heinecke (Intel Corp.), Hans Pabst (Intel Corp.)
11 ******************************************************************************/
12 #include <libxsmm.h>
13
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #if defined(_OPENMP)
21 # include <omp.h>
22 #endif
23 #if defined(__MKL)
24 # include <mkl_service.h>
25 #endif
26 #if defined(LIBXSMM_OFFLOAD_TARGET)
27 # pragma offload_attribute(pop)
28 #endif
29
30 #if !defined(ITYPE)
31 # define ITYPE float
32 #endif
33
34 #if !defined(CHECK) && (LIBXSMM_EQUAL(ITYPE, float) || LIBXSMM_EQUAL(ITYPE, double))
35 # if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
LIBXSMM_BLAS_SYMBOL_DECL(ITYPE,gemm)36 LIBXSMM_BLAS_SYMBOL_DECL(ITYPE, gemm)
37 # endif
38 # define CHECK
39 #endif
40
/* Terminate the sample with exit status 1 after printing the stringified
 * condition when it does not hold. The do/while(0) wrapper makes "MYASSERT(x);"
 * a single statement, so it is safe inside unbraced if/else constructs. */
#define MYASSERT(x) do { \
  if (!(x)) { \
    printf("Assertion %s failed...\n", #x); \
    exit(1); \
  } \
} while (0)
42
43
44 int main(int argc, char* argv[])
45 {
46 LIBXSMM_BLAS_CONST libxsmm_blasint m = (1 < argc ? atoi(argv[1]) : 1024);
47 LIBXSMM_BLAS_CONST libxsmm_blasint k = (3 < argc ? atoi(argv[3]) : m);
48 LIBXSMM_BLAS_CONST libxsmm_blasint n = (2 < argc ? atoi(argv[2]) : k);
49 const libxsmm_blasint bm = (4 < argc ? atoi(argv[4]) : 32);
50 const libxsmm_blasint bk = (6 < argc ? atoi(argv[6]) : bm);
51 const libxsmm_blasint bn = (5 < argc ? atoi(argv[5]) : bk);
52 const libxsmm_blocked_gemm_order order = (libxsmm_blocked_gemm_order)(7 < argc ? atoi(argv[7]) : 0);
53 const int nrepeat = (8 < argc ? atoi(argv[8]) : 100);
54 const libxsmm_blasint b_m1 = (9 < argc ? atoi(argv[9]) : 1);
55 const libxsmm_blasint b_n1 = (10 < argc ? atoi(argv[10]) : 1);
56 const libxsmm_blasint b_k1 = (11 < argc ? atoi(argv[11]) : 1);
57 const libxsmm_blasint b_k2 = (12 < argc ? atoi(argv[12]) : 1);
58 const int ab = (13 < argc ? atoi(argv[13]) : 0);
59 LIBXSMM_BLAS_CONST libxsmm_blasint lda = (14 < argc ? atoi(argv[13]) : m);
60 LIBXSMM_BLAS_CONST libxsmm_blasint ldb = (15 < argc ? atoi(argv[14]) : k);
61 LIBXSMM_BLAS_CONST libxsmm_blasint ldc = (16 < argc ? atoi(argv[15]) : m);
62 LIBXSMM_BLAS_CONST char transa = 'N', transb = 'N'; /* no transposes */
63 LIBXSMM_BLAS_CONST ITYPE alpha = 1, beta = 1;
64 const int gemm_flags = LIBXSMM_GEMM_FLAGS(transa, transb);
65 const double gflops = 2.0 * m * n * k * 1E-9;
66 int result = EXIT_SUCCESS;
67 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
68 const char *const env_check = getenv("CHECK");
69 const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check));
70 #endif
71 if (argc > 1 && !strncmp(argv[1], "-h", 3)) { /* check command line */
72 printf("\nUsage: ./bgemm [M] [N] [K] [bm] [bn] [bk] [order] [reps] [b_m1] [b_n1] [b_k1] [b_k2] [verbose]\n\n");
73 return result;
74 }
75
76 MYASSERT(m % b_m1 == 0);
77 MYASSERT(n % b_n1 == 0);
78 MYASSERT(k % b_k1 == 0);
79 MYASSERT(m/b_m1 % bm == 0);
80 MYASSERT(n/b_n1 % bn == 0);
81 MYASSERT(k/b_k1/b_k2 % bk == 0);
82
83 #if defined(LIBXSMM_OFFLOAD_TARGET)
84 # pragma offload target(LIBXSMM_OFFLOAD_TARGET)
85 #endif
86 {
87 ITYPE* agold = (ITYPE*)libxsmm_malloc((size_t)lda * (size_t)k * sizeof(ITYPE));
88 ITYPE* bgold = (ITYPE*)libxsmm_malloc((size_t)ldb * (size_t)n * sizeof(ITYPE));
89 ITYPE* cgold = (ITYPE*)libxsmm_malloc((size_t)ldc * (size_t)n * sizeof(ITYPE));
90 ITYPE* a = (ITYPE*)libxsmm_malloc((size_t)m * (size_t)k * sizeof(ITYPE));
91 ITYPE* b = (ITYPE*)libxsmm_malloc((size_t)k * (size_t)n * sizeof(ITYPE));
92 ITYPE* c = (ITYPE*)libxsmm_malloc((size_t)m * (size_t)n * sizeof(ITYPE));
93 libxsmm_blocked_gemm_handle* handle = 0;
94 unsigned long long start;
95 double duration;
96 #if defined(_OPENMP)
97 const int nthreads = omp_get_max_threads();
98 #else
99 const int nthreads = 1;
100 #endif
101 handle = libxsmm_blocked_gemm_handle_create(nthreads,
102 LIBXSMM_GEMM_PRECISION(ITYPE), LIBXSMM_GEMM_PRECISION(ITYPE),
103 m, n, k, &bm, &bn, &bk, &b_m1, &b_n1, &b_k1, &b_k2,
104 &alpha, &beta, &gemm_flags, NULL/*auto-prefetch*/, &order);
105
106 if (0 != handle) {
107 LIBXSMM_MATINIT_OMP(ITYPE, 42, agold, m, k, lda, 1.0);
108 LIBXSMM_MATINIT_OMP(ITYPE, 24, bgold, k, n, ldb, 1.0);
109 LIBXSMM_MATINIT_OMP(ITYPE, 0, cgold, m, n, ldc, 1.0);
110 libxsmm_blocked_gemm_copyin_a(handle, agold, &lda, a);
111 libxsmm_blocked_gemm_copyin_b(handle, bgold, &ldb, b);
112 libxsmm_blocked_gemm_copyin_c(handle, cgold, &ldc, c);
113 #if defined(MKL_ENABLE_AVX512)
114 mkl_enable_instructions(MKL_ENABLE_AVX512);
115 #endif
116 /* warm-up OpenMP (populate thread pool) */
117 libxsmm_blocked_gemm_omp(handle, a, b, c, 1);
118 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
119 if (!LIBXSMM_FEQ(0, check)) {
120 LIBXSMM_GEMM_SYMBOL(ITYPE)(&transa, &transb, &m, &n, &k, &alpha, agold, &lda, bgold, &ldb, &beta, cgold, &ldc);
121 }
122 #endif
123 if (!ab) {
124 libxsmm_gemm_print(stdout, LIBXSMM_GEMM_PRECISION(ITYPE),
125 &transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
126 fprintf(stdout, "\n\n");
127 }
128 start = libxsmm_timer_tick();
129 libxsmm_blocked_gemm_omp(handle, a, b, c, nrepeat);
130 duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
131 if (0 < duration) {
132 if (ab) {
133 fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s | %lli,%lli,%lli,%lli,%lli,%lli,%i,%lli,%lli,%lli,%lli\n",
134 gflops * nrepeat / duration, (long long)m, (long long)n, (long long)k, (long long)bm, (long long)bn, (long long)bk,
135 (int)order, (long long)b_m1, (long long)b_n1, (long long)b_k1, (long long)b_k2);
136 } else {
137 fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
138 }
139 }
140 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
141 if (!LIBXSMM_FEQ(0, check)) { /* validate result against LAPACK/BLAS xGEMM */
142 ITYPE* ctest = 0;
143 int i;
144 start = libxsmm_timer_tick();
145 for (i = 0; i < nrepeat; ++i) {
146 LIBXSMM_GEMM_SYMBOL(ITYPE)(&transa, &transb, &m, &n, &k, &alpha, agold, &lda, bgold, &ldb, &beta, cgold, &ldc);
147 }
148 duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
149 if (0 < duration) {
150 fprintf(stdout, "\tBLAS: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
151 }
152 /* free memory not needed further; avoid double-free later on */
153 libxsmm_free(agold); agold = 0;
154 libxsmm_free(bgold); bgold = 0;
155 libxsmm_free(a); a = 0;
156 libxsmm_free(b); b = 0;
157 /* allocate C-matrix in regular format, and perform copy-out */
158 ctest = (ITYPE*)libxsmm_malloc((size_t)(sizeof(ITYPE) * ldc * n));
159 if (0 != ctest) {
160 libxsmm_matdiff_info diff;
161 libxsmm_blocked_gemm_copyout_c(handle, c, &ldc, ctest);
162 result = libxsmm_matdiff(&diff, LIBXSMM_DATATYPE(ITYPE), m, n, cgold, ctest, &ldc, &ldc);
163 if (EXIT_SUCCESS == result) {
164 fprintf(stdout, "\tdiff: L2abs=%f Linf=%f\n", diff.l2_abs, diff.linf_abs);
165 if (check < 100.0 * diff.normf_rel) {
166 fprintf(stderr, "FAILED with an error of %f%%!\n", 100.0 * diff.normf_rel);
167 result = EXIT_FAILURE;
168 }
169 }
170 libxsmm_free(ctest);
171 }
172 }
173 #endif
174 libxsmm_blocked_gemm_handle_destroy(handle);
175 }
176 else {
177 fprintf(stderr, "FAILED to create BGEMM-handle! For details retry with LIBXSMM_VERBOSE=1.\n");
178 result = EXIT_FAILURE;
179 }
180 libxsmm_free(agold);
181 libxsmm_free(bgold);
182 libxsmm_free(cgold);
183 libxsmm_free(a);
184 libxsmm_free(b);
185 libxsmm_free(c);
186 }
187 if (!ab) {
188 fprintf(stdout, "Finished\n");
189 }
190 return result;
191 }
192
193