1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12 
13 #if defined(LIBXSMM_OFFLOAD_TARGET)
14 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
15 #endif
16 #if defined(__MKL)
17 # include <mkl_service.h>
18 #endif
19 #include <stdlib.h>
20 #include <stdio.h>
21 #if defined(LIBXSMM_OFFLOAD_TARGET)
22 # pragma offload_attribute(pop)
23 #endif
24 
/* Element type of the A and B operands; defaults to double. */
#ifndef ITYPE
# define ITYPE double
#endif
/* Element type of the C operand; defaults to the input type. */
#ifndef OTYPE
# define OTYPE ITYPE
#endif

/* Compile-time switch: change the 0 to 1 to benchmark libxsmm_xgemm (SEQUENTIAL). */
#if 0 && !defined(SEQUENTIAL)
# define SEQUENTIAL
#endif

/* XGEMM: the LIBXSMM GEMM routine under test. SEQUENTIAL selects libxsmm_xgemm,
 * otherwise the LIBXSMM_YGEMM_SYMBOL(ITYPE) routine is called. A definition
 * supplied from outside (e.g. -DXGEMM=...) takes precedence. */
#ifndef XGEMM
# ifdef SEQUENTIAL
#   define XGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
      libxsmm_xgemm(LIBXSMM_GEMM_PRECISION(ITYPE), LIBXSMM_GEMM_PRECISION(OTYPE), \
        TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
# else
#   define XGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
      LIBXSMM_YGEMM_SYMBOL(ITYPE)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
# endif
#endif
46 
47 #if !defined(CHECK) && (LIBXSMM_EQUAL(ITYPE, float) || LIBXSMM_EQUAL(ITYPE, double))
48 # if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
LIBXSMM_BLAS_SYMBOL_DECL(ITYPE,gemm)49 LIBXSMM_BLAS_SYMBOL_DECL(ITYPE, gemm)
50 # endif
51 # define XGEMM_GOLD(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
52     LIBXSMM_GEMM_SYMBOL(ITYPE)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
53 # define CHECK
54 #endif
55 
56 
57 int main(int argc, char* argv[])
58 {
59   LIBXSMM_BLAS_CONST libxsmm_blasint m = (1 < argc ? atoi(argv[1]) : 512);
60   LIBXSMM_BLAS_CONST libxsmm_blasint k = (3 < argc ? atoi(argv[3]) : m);
61   LIBXSMM_BLAS_CONST libxsmm_blasint n = (2 < argc ? atoi(argv[2]) : k), nn = n;
62   LIBXSMM_BLAS_CONST OTYPE alpha = (OTYPE)(7 < argc ? atof(argv[7]) : 1.0);
63   LIBXSMM_BLAS_CONST OTYPE beta  = (OTYPE)(8 < argc ? atof(argv[8]) : 1.0);
64   LIBXSMM_BLAS_CONST char transa = (/*LIBXSMM_BLAS_CONST*/ char)( 9 < argc ? *argv[9]  : 'N');
65   LIBXSMM_BLAS_CONST char transb = (/*LIBXSMM_BLAS_CONST*/ char)(10 < argc ? *argv[10] : 'N');
66   LIBXSMM_BLAS_CONST libxsmm_blasint mm = (('N' == transa || 'n' == transa) ? m : k);
67   LIBXSMM_BLAS_CONST libxsmm_blasint kk = (('N' == transb || 'n' == transb) ? k : n);
68   LIBXSMM_BLAS_CONST libxsmm_blasint ka = (('N' == transa || 'n' == transa) ? k : m);
69   LIBXSMM_BLAS_CONST libxsmm_blasint kb = (('N' == transb || 'n' == transb) ? n : k);
70   LIBXSMM_BLAS_CONST libxsmm_blasint lda = ((4 < argc && mm < atoi(argv[4])) ? atoi(argv[4]) : mm);
71   LIBXSMM_BLAS_CONST libxsmm_blasint ldb = ((5 < argc && kk < atoi(argv[5])) ? atoi(argv[5]) : kk);
72   LIBXSMM_BLAS_CONST libxsmm_blasint ldc = ((6 < argc && m < atoi(argv[6])) ? atoi(argv[6]) : m);
73   const int nrepeat = ((11 < argc && 0 < atoi(argv[11])) ? atoi(argv[11])
74     : LIBXSMM_MAX(13 / LIBXSMM_MAX(1, (int)(libxsmm_icbrt_u64(1ULL * m * n * k) >> 10)), 3));
75   const double gflops = 2.0 * m * n * k * 1E-9;
76   int result = EXIT_SUCCESS;
77 #if defined(CHECK)
78   const char *const env_check = getenv("CHECK");
79   const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check));
80 #endif
81 #if defined(LIBXSMM_OFFLOAD_TARGET)
82 # pragma offload target(LIBXSMM_OFFLOAD_TARGET)
83 #endif
84   {
85     const char *const env_tasks = getenv("TASKS");
86     const int tasks = (NULL == env_tasks || 0 == *env_tasks) ? 0/*default*/ : atoi(env_tasks);
87     ITYPE *const a = (ITYPE*)libxsmm_malloc((size_t)(lda * ka * sizeof(ITYPE)));
88     ITYPE *const b = (ITYPE*)libxsmm_malloc((size_t)(ldb * kb * sizeof(ITYPE)));
89     OTYPE *const c = (OTYPE*)libxsmm_malloc((size_t)(ldc * nn * sizeof(OTYPE)));
90 #if defined(CHECK)
91     OTYPE* d = 0;
92     if (!LIBXSMM_FEQ(0, check)) {
93       d = (OTYPE*)libxsmm_malloc((size_t)(ldc * nn * sizeof(OTYPE)));
94       LIBXSMM_MATINIT_OMP(OTYPE, 0, d, m, n, ldc, 1.0);
95     }
96 #endif
97     LIBXSMM_MATINIT_OMP(OTYPE,  0, c,  m,  n, ldc, 1.0);
98     LIBXSMM_MATINIT_OMP(ITYPE, 42, a, mm, ka, lda, 1.0);
99     LIBXSMM_MATINIT_OMP(ITYPE, 24, b, kk, kb, ldb, 1.0);
100 #if defined(MKL_ENABLE_AVX512)
101     mkl_enable_instructions(MKL_ENABLE_AVX512);
102 #endif
103     /* warm-up OpenMP (populate thread pool) */
104 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
105     if (0 != d) XGEMM_GOLD(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, d, &ldc);
106 #endif
107     XGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
108     libxsmm_gemm_print(stdout, LIBXSMM_GEMM_PRECISION(ITYPE),
109       &transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
110     fprintf(stdout, "\n\n");
111 
112     if (0 == tasks) { /* tiled xGEMM (with library-internal parallelization) */
113       int i; double duration;
114       unsigned long long start = libxsmm_timer_tick();
115       for (i = 0; i < nrepeat; ++i) {
116         XGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
117       }
118       duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
119       if (0 < duration) {
120         fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
121       }
122     }
123     else { /* tiled xGEMM (with external parallelization) */
124       int i; double duration;
125       unsigned long long start = libxsmm_timer_tick();
126       for (i = 0; i < nrepeat; ++i) {
127 #if defined(_OPENMP)
128 #       pragma omp parallel
129 #       pragma omp single nowait
130 #endif
131         XGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
132       }
133       duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
134       if (0 < duration) {
135         fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
136       }
137     }
138 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
139     if (0 != d) { /* validate result against LAPACK/BLAS xGEMM */
140       libxsmm_matdiff_info diff;
141       int i; double duration;
142       unsigned long long start = libxsmm_timer_tick();
143       for (i = 0; i < nrepeat; ++i) {
144         XGEMM_GOLD(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, d, &ldc);
145       }
146       duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
147 
148       if (0 < duration) {
149         fprintf(stdout, "\tBLAS: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
150       }
151       result = libxsmm_matdiff(&diff, LIBXSMM_DATATYPE(OTYPE), m, n, d, c, &ldc, &ldc);
152       if (EXIT_SUCCESS == result) {
153         fprintf(stdout, "\tdiff: L2abs=%f Linf=%f\n", diff.l2_abs, diff.linf_abs);
154         if (check < diff.l2_rel) {
155           fprintf(stderr, "FAILED.\n");
156           result = EXIT_FAILURE;
157         }
158       }
159       libxsmm_free(d);
160     }
161 #endif
162     libxsmm_free(c);
163     libxsmm_free(a);
164     libxsmm_free(b);
165   }
166   fprintf(stdout, "Finished\n");
167   return result;
168 }
169 
170