1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Kunal Banerjee (Intel Corp.), Dheevatsa Mudigere (Intel Corp.)
10    Alexander Heinecke (Intel Corp.), Hans Pabst (Intel Corp.)
11 ******************************************************************************/
12 #include <libxsmm.h>
13 
14 #if defined(LIBXSMM_OFFLOAD_TARGET)
15 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
16 #endif
17 #include <stdlib.h>
18 #include <string.h>
19 #include <stdio.h>
20 #if defined(_OPENMP)
21 # include <omp.h>
22 #endif
23 #if defined(__MKL)
24 # include <mkl_service.h>
25 #endif
26 #if defined(LIBXSMM_OFFLOAD_TARGET)
27 # pragma offload_attribute(pop)
28 #endif
29 
30 #if !defined(ITYPE)
31 # define ITYPE float
32 #endif
33 
34 #if !defined(CHECK) && (LIBXSMM_EQUAL(ITYPE, float) || LIBXSMM_EQUAL(ITYPE, double))
35 # if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
LIBXSMM_BLAS_SYMBOL_DECL(ITYPE,gemm)36 LIBXSMM_BLAS_SYMBOL_DECL(ITYPE, gemm)
37 # endif
38 # define CHECK
39 #endif
40 
/* Terminate the benchmark with a diagnostic on stderr if condition x is false.
 * The do/while(0) wrapper makes "MYASSERT(x);" a single statement, so it is
 * safe to use as the body of an if/else without braces. */
#define MYASSERT(x) do { \
    if (!(x)) { \
      fprintf(stderr, "Assertion %s failed...\n", #x); \
      exit(EXIT_FAILURE); \
    } \
  } while (0)
42 
43 
44 int main(int argc, char* argv[])
45 {
46   LIBXSMM_BLAS_CONST libxsmm_blasint m = (1 < argc ? atoi(argv[1]) : 1024);
47   LIBXSMM_BLAS_CONST libxsmm_blasint k = (3 < argc ? atoi(argv[3]) : m);
48   LIBXSMM_BLAS_CONST libxsmm_blasint n = (2 < argc ? atoi(argv[2]) : k);
49   const libxsmm_blasint bm = (4 < argc ? atoi(argv[4]) : 32);
50   const libxsmm_blasint bk = (6 < argc ? atoi(argv[6]) : bm);
51   const libxsmm_blasint bn = (5 < argc ? atoi(argv[5]) : bk);
52   const libxsmm_blocked_gemm_order order = (libxsmm_blocked_gemm_order)(7 < argc ? atoi(argv[7]) : 0);
53   const int nrepeat = (8 < argc ? atoi(argv[8]) : 100);
54   const libxsmm_blasint b_m1 = (9 < argc ? atoi(argv[9]) : 1);
55   const libxsmm_blasint b_n1  = (10 < argc ? atoi(argv[10]) : 1);
56   const libxsmm_blasint b_k1 = (11 < argc ? atoi(argv[11]) : 1);
57   const libxsmm_blasint b_k2 = (12 < argc ? atoi(argv[12]) : 1);
58   const int ab = (13 < argc ? atoi(argv[13]) : 0);
59   LIBXSMM_BLAS_CONST libxsmm_blasint lda = (14 < argc ? atoi(argv[13]) : m);
60   LIBXSMM_BLAS_CONST libxsmm_blasint ldb = (15 < argc ? atoi(argv[14]) : k);
61   LIBXSMM_BLAS_CONST libxsmm_blasint ldc = (16 < argc ? atoi(argv[15]) : m);
62   LIBXSMM_BLAS_CONST char transa = 'N', transb = 'N'; /* no transposes */
63   LIBXSMM_BLAS_CONST ITYPE alpha = 1, beta = 1;
64   const int gemm_flags = LIBXSMM_GEMM_FLAGS(transa, transb);
65   const double gflops = 2.0 * m * n * k * 1E-9;
66   int result = EXIT_SUCCESS;
67 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
68   const char *const env_check = getenv("CHECK");
69   const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check));
70 #endif
71   if (argc > 1 && !strncmp(argv[1], "-h", 3)) { /* check command line */
72     printf("\nUsage: ./bgemm [M] [N] [K] [bm] [bn] [bk] [order] [reps] [b_m1] [b_n1] [b_k1] [b_k2] [verbose]\n\n");
73     return result;
74   }
75 
76   MYASSERT(m % b_m1 == 0);
77   MYASSERT(n % b_n1 == 0);
78   MYASSERT(k % b_k1 == 0);
79   MYASSERT(m/b_m1 % bm == 0);
80   MYASSERT(n/b_n1 % bn == 0);
81   MYASSERT(k/b_k1/b_k2 % bk == 0);
82 
83 #if defined(LIBXSMM_OFFLOAD_TARGET)
84 # pragma offload target(LIBXSMM_OFFLOAD_TARGET)
85 #endif
86   {
87     ITYPE* agold = (ITYPE*)libxsmm_malloc((size_t)lda * (size_t)k * sizeof(ITYPE));
88     ITYPE* bgold = (ITYPE*)libxsmm_malloc((size_t)ldb * (size_t)n * sizeof(ITYPE));
89     ITYPE* cgold = (ITYPE*)libxsmm_malloc((size_t)ldc * (size_t)n * sizeof(ITYPE));
90     ITYPE* a = (ITYPE*)libxsmm_malloc((size_t)m * (size_t)k * sizeof(ITYPE));
91     ITYPE* b = (ITYPE*)libxsmm_malloc((size_t)k * (size_t)n * sizeof(ITYPE));
92     ITYPE* c = (ITYPE*)libxsmm_malloc((size_t)m * (size_t)n * sizeof(ITYPE));
93     libxsmm_blocked_gemm_handle* handle = 0;
94     unsigned long long start;
95     double duration;
96 #if defined(_OPENMP)
97     const int nthreads = omp_get_max_threads();
98 #else
99     const int nthreads = 1;
100 #endif
101     handle = libxsmm_blocked_gemm_handle_create(nthreads,
102       LIBXSMM_GEMM_PRECISION(ITYPE), LIBXSMM_GEMM_PRECISION(ITYPE),
103       m, n, k, &bm, &bn, &bk, &b_m1, &b_n1, &b_k1, &b_k2,
104       &alpha, &beta, &gemm_flags, NULL/*auto-prefetch*/, &order);
105 
106     if (0 != handle) {
107       LIBXSMM_MATINIT_OMP(ITYPE, 42, agold, m, k, lda, 1.0);
108       LIBXSMM_MATINIT_OMP(ITYPE, 24, bgold, k, n, ldb, 1.0);
109       LIBXSMM_MATINIT_OMP(ITYPE,  0, cgold, m, n, ldc, 1.0);
110       libxsmm_blocked_gemm_copyin_a(handle, agold, &lda, a);
111       libxsmm_blocked_gemm_copyin_b(handle, bgold, &ldb, b);
112       libxsmm_blocked_gemm_copyin_c(handle, cgold, &ldc, c);
113 #if defined(MKL_ENABLE_AVX512)
114       mkl_enable_instructions(MKL_ENABLE_AVX512);
115 #endif
116       /* warm-up OpenMP (populate thread pool) */
117       libxsmm_blocked_gemm_omp(handle, a, b, c, 1);
118 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
119       if (!LIBXSMM_FEQ(0, check)) {
120         LIBXSMM_GEMM_SYMBOL(ITYPE)(&transa, &transb, &m, &n, &k, &alpha, agold, &lda, bgold, &ldb, &beta, cgold, &ldc);
121       }
122 #endif
123       if (!ab) {
124       libxsmm_gemm_print(stdout, LIBXSMM_GEMM_PRECISION(ITYPE),
125         &transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
126       fprintf(stdout, "\n\n");
127       }
128       start = libxsmm_timer_tick();
129       libxsmm_blocked_gemm_omp(handle, a, b, c, nrepeat);
130       duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
131       if (0 < duration) {
132         if (ab) {
133           fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s | %lli,%lli,%lli,%lli,%lli,%lli,%i,%lli,%lli,%lli,%lli\n",
134             gflops * nrepeat / duration, (long long)m, (long long)n, (long long)k, (long long)bm, (long long)bn, (long long)bk,
135             (int)order, (long long)b_m1, (long long)b_n1, (long long)b_k1, (long long)b_k2);
136         } else {
137           fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
138         }
139       }
140 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
141       if (!LIBXSMM_FEQ(0, check)) { /* validate result against LAPACK/BLAS xGEMM */
142         ITYPE* ctest = 0;
143         int i;
144         start = libxsmm_timer_tick();
145         for (i = 0; i < nrepeat; ++i) {
146           LIBXSMM_GEMM_SYMBOL(ITYPE)(&transa, &transb, &m, &n, &k, &alpha, agold, &lda, bgold, &ldb, &beta, cgold, &ldc);
147         }
148         duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
149         if (0 < duration) {
150           fprintf(stdout, "\tBLAS: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
151         }
152         /* free memory not needed further; avoid double-free later on */
153         libxsmm_free(agold); agold = 0;
154         libxsmm_free(bgold); bgold = 0;
155         libxsmm_free(a); a = 0;
156         libxsmm_free(b); b = 0;
157         /* allocate C-matrix in regular format, and perform copy-out */
158         ctest = (ITYPE*)libxsmm_malloc((size_t)(sizeof(ITYPE) * ldc * n));
159         if (0 != ctest) {
160           libxsmm_matdiff_info diff;
161           libxsmm_blocked_gemm_copyout_c(handle, c, &ldc, ctest);
162           result = libxsmm_matdiff(&diff, LIBXSMM_DATATYPE(ITYPE), m, n, cgold, ctest, &ldc, &ldc);
163           if (EXIT_SUCCESS == result) {
164             fprintf(stdout, "\tdiff: L2abs=%f Linf=%f\n", diff.l2_abs, diff.linf_abs);
165             if (check < 100.0 * diff.normf_rel) {
166               fprintf(stderr, "FAILED with an error of %f%%!\n", 100.0 * diff.normf_rel);
167               result = EXIT_FAILURE;
168             }
169           }
170           libxsmm_free(ctest);
171         }
172       }
173 #endif
174       libxsmm_blocked_gemm_handle_destroy(handle);
175     }
176     else {
177       fprintf(stderr, "FAILED to create BGEMM-handle! For details retry with LIBXSMM_VERBOSE=1.\n");
178       result = EXIT_FAILURE;
179     }
180     libxsmm_free(agold);
181     libxsmm_free(bgold);
182     libxsmm_free(cgold);
183     libxsmm_free(a);
184     libxsmm_free(b);
185     libxsmm_free(c);
186   }
187   if (!ab) {
188     fprintf(stdout, "Finished\n");
189   }
190   return result;
191 }
192 
193