1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #include <libxsmm.h>
12
13 #if defined(LIBXSMM_OFFLOAD_TARGET)
14 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
15 #endif
16 #if defined(__MKL)
17 # include <mkl_service.h>
18 #endif
19 #include <stdlib.h>
20 #include <stdio.h>
21 #if defined(LIBXSMM_OFFLOAD_TARGET)
22 # pragma offload_attribute(pop)
23 #endif
24
/* Element type of the input matrices A and B (override with -DITYPE=...). */
#ifndef ITYPE
# define ITYPE double
#endif
/* Element type of the output matrix C; follows the input type unless overridden. */
#ifndef OTYPE
# define OTYPE ITYPE
#endif

/* Compile-time switch: replace the leading "0" with "1" (or build with
 * -DSEQUENTIAL) to benchmark libxsmm_xgemm rather than the
 * LIBXSMM_YGEMM_SYMBOL(ITYPE) routine. */
#if 0 && !defined(SEQUENTIAL)
# define SEQUENTIAL
#endif
35
/* XGEMM: the GEMM routine under test (can be overridden with -DXGEMM=...).
 * Default: LIBXSMM_YGEMM_SYMBOL(ITYPE) with the BLAS-style argument list.
 * With SEQUENTIAL defined: libxsmm_xgemm, which additionally receives the
 * input (ITYPE) and output (OTYPE) precisions explicitly. */
#ifndef XGEMM
# if !defined(SEQUENTIAL)
#   define XGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
      LIBXSMM_YGEMM_SYMBOL(ITYPE)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
# else
#   define XGEMM(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
      libxsmm_xgemm(LIBXSMM_GEMM_PRECISION(ITYPE), LIBXSMM_GEMM_PRECISION(OTYPE), \
        TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
# endif
#endif
46
47 #if !defined(CHECK) && (LIBXSMM_EQUAL(ITYPE, float) || LIBXSMM_EQUAL(ITYPE, double))
48 # if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
LIBXSMM_BLAS_SYMBOL_DECL(ITYPE,gemm)49 LIBXSMM_BLAS_SYMBOL_DECL(ITYPE, gemm)
50 # endif
51 # define XGEMM_GOLD(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) \
52 LIBXSMM_GEMM_SYMBOL(ITYPE)(TRANSA, TRANSB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC)
53 # define CHECK
54 #endif
55
56
57 int main(int argc, char* argv[])
58 {
59 LIBXSMM_BLAS_CONST libxsmm_blasint m = (1 < argc ? atoi(argv[1]) : 512);
60 LIBXSMM_BLAS_CONST libxsmm_blasint k = (3 < argc ? atoi(argv[3]) : m);
61 LIBXSMM_BLAS_CONST libxsmm_blasint n = (2 < argc ? atoi(argv[2]) : k), nn = n;
62 LIBXSMM_BLAS_CONST OTYPE alpha = (OTYPE)(7 < argc ? atof(argv[7]) : 1.0);
63 LIBXSMM_BLAS_CONST OTYPE beta = (OTYPE)(8 < argc ? atof(argv[8]) : 1.0);
64 LIBXSMM_BLAS_CONST char transa = (/*LIBXSMM_BLAS_CONST*/ char)( 9 < argc ? *argv[9] : 'N');
65 LIBXSMM_BLAS_CONST char transb = (/*LIBXSMM_BLAS_CONST*/ char)(10 < argc ? *argv[10] : 'N');
66 LIBXSMM_BLAS_CONST libxsmm_blasint mm = (('N' == transa || 'n' == transa) ? m : k);
67 LIBXSMM_BLAS_CONST libxsmm_blasint kk = (('N' == transb || 'n' == transb) ? k : n);
68 LIBXSMM_BLAS_CONST libxsmm_blasint ka = (('N' == transa || 'n' == transa) ? k : m);
69 LIBXSMM_BLAS_CONST libxsmm_blasint kb = (('N' == transb || 'n' == transb) ? n : k);
70 LIBXSMM_BLAS_CONST libxsmm_blasint lda = ((4 < argc && mm < atoi(argv[4])) ? atoi(argv[4]) : mm);
71 LIBXSMM_BLAS_CONST libxsmm_blasint ldb = ((5 < argc && kk < atoi(argv[5])) ? atoi(argv[5]) : kk);
72 LIBXSMM_BLAS_CONST libxsmm_blasint ldc = ((6 < argc && m < atoi(argv[6])) ? atoi(argv[6]) : m);
73 const int nrepeat = ((11 < argc && 0 < atoi(argv[11])) ? atoi(argv[11])
74 : LIBXSMM_MAX(13 / LIBXSMM_MAX(1, (int)(libxsmm_icbrt_u64(1ULL * m * n * k) >> 10)), 3));
75 const double gflops = 2.0 * m * n * k * 1E-9;
76 int result = EXIT_SUCCESS;
77 #if defined(CHECK)
78 const char *const env_check = getenv("CHECK");
79 const double check = LIBXSMM_ABS(NULL == env_check ? 0 : atof(env_check));
80 #endif
81 #if defined(LIBXSMM_OFFLOAD_TARGET)
82 # pragma offload target(LIBXSMM_OFFLOAD_TARGET)
83 #endif
84 {
85 const char *const env_tasks = getenv("TASKS");
86 const int tasks = (NULL == env_tasks || 0 == *env_tasks) ? 0/*default*/ : atoi(env_tasks);
87 ITYPE *const a = (ITYPE*)libxsmm_malloc((size_t)(lda * ka * sizeof(ITYPE)));
88 ITYPE *const b = (ITYPE*)libxsmm_malloc((size_t)(ldb * kb * sizeof(ITYPE)));
89 OTYPE *const c = (OTYPE*)libxsmm_malloc((size_t)(ldc * nn * sizeof(OTYPE)));
90 #if defined(CHECK)
91 OTYPE* d = 0;
92 if (!LIBXSMM_FEQ(0, check)) {
93 d = (OTYPE*)libxsmm_malloc((size_t)(ldc * nn * sizeof(OTYPE)));
94 LIBXSMM_MATINIT_OMP(OTYPE, 0, d, m, n, ldc, 1.0);
95 }
96 #endif
97 LIBXSMM_MATINIT_OMP(OTYPE, 0, c, m, n, ldc, 1.0);
98 LIBXSMM_MATINIT_OMP(ITYPE, 42, a, mm, ka, lda, 1.0);
99 LIBXSMM_MATINIT_OMP(ITYPE, 24, b, kk, kb, ldb, 1.0);
100 #if defined(MKL_ENABLE_AVX512)
101 mkl_enable_instructions(MKL_ENABLE_AVX512);
102 #endif
103 /* warm-up OpenMP (populate thread pool) */
104 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
105 if (0 != d) XGEMM_GOLD(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, d, &ldc);
106 #endif
107 XGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
108 libxsmm_gemm_print(stdout, LIBXSMM_GEMM_PRECISION(ITYPE),
109 &transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
110 fprintf(stdout, "\n\n");
111
112 if (0 == tasks) { /* tiled xGEMM (with library-internal parallelization) */
113 int i; double duration;
114 unsigned long long start = libxsmm_timer_tick();
115 for (i = 0; i < nrepeat; ++i) {
116 XGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
117 }
118 duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
119 if (0 < duration) {
120 fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
121 }
122 }
123 else { /* tiled xGEMM (with external parallelization) */
124 int i; double duration;
125 unsigned long long start = libxsmm_timer_tick();
126 for (i = 0; i < nrepeat; ++i) {
127 #if defined(_OPENMP)
128 # pragma omp parallel
129 # pragma omp single nowait
130 #endif
131 XGEMM(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
132 }
133 duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
134 if (0 < duration) {
135 fprintf(stdout, "\tLIBXSMM: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
136 }
137 }
138 #if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
139 if (0 != d) { /* validate result against LAPACK/BLAS xGEMM */
140 libxsmm_matdiff_info diff;
141 int i; double duration;
142 unsigned long long start = libxsmm_timer_tick();
143 for (i = 0; i < nrepeat; ++i) {
144 XGEMM_GOLD(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, d, &ldc);
145 }
146 duration = libxsmm_timer_duration(start, libxsmm_timer_tick());
147
148 if (0 < duration) {
149 fprintf(stdout, "\tBLAS: %.1f GFLOPS/s\n", gflops * nrepeat / duration);
150 }
151 result = libxsmm_matdiff(&diff, LIBXSMM_DATATYPE(OTYPE), m, n, d, c, &ldc, &ldc);
152 if (EXIT_SUCCESS == result) {
153 fprintf(stdout, "\tdiff: L2abs=%f Linf=%f\n", diff.l2_abs, diff.linf_abs);
154 if (check < diff.l2_rel) {
155 fprintf(stderr, "FAILED.\n");
156 result = EXIT_FAILURE;
157 }
158 }
159 libxsmm_free(d);
160 }
161 #endif
162 libxsmm_free(c);
163 libxsmm_free(a);
164 libxsmm_free(b);
165 }
166 fprintf(stdout, "Finished\n");
167 return result;
168 }
169
170