1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 
12 /** This sample uses LIBXSMM's header-only implementation. */
13 #include <libxsmm_source.h>
14 
/* LIBXSMM is always enabled for this sample. When it is, Eigen is asked to
 * compile its AVX code paths (EIGEN_VECTORIZE_AVX) and to use LIBXSMM
 * (EIGEN_USE_LIBXSMM); user-provided definitions are respected.
 */
#ifndef USE_LIBXSMM
# define USE_LIBXSMM
#endif

#ifdef USE_LIBXSMM
# ifndef EIGEN_VECTORIZE_AVX
#   define EIGEN_VECTORIZE_AVX
# endif
# ifndef EIGEN_USE_LIBXSMM
#   define EIGEN_USE_LIBXSMM
# endif
#endif

/* Automatic enabling of Eigen (and its unsupported Tensor module) is switched
 * off by the leading constant 0; define __EIGEN_UNSUPPORTED (and __EIGEN) on
 * the command line to build the actual benchmark.
 */
#if 0 && !defined(__EIGEN) && !defined(__EIGEN_UNSUPPORTED)
# define __EIGEN_UNSUPPORTED
# define __EIGEN
#endif

/* Let Eigen use threads when Eigen is enabled and either OpenMP is on, no
 * BLAS setting is given, or the BLAS setting is greater than 1.
 */
#if !defined(EIGEN_USE_THREADS) && defined(__EIGEN) \
  && (defined(_OPENMP) || !defined(__BLAS) || 1 < (__BLAS))
# define EIGEN_USE_THREADS
#endif
37 
38 #if defined(LIBXSMM_OFFLOAD_TARGET)
39 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
40 #endif
41 #if defined(__EIGEN_UNSUPPORTED)
42 # include <unsupported/Eigen/CXX11/Tensor>
43 # include <unsupported/Eigen/CXX11/ThreadPool>
44 #endif
45 #include <algorithm>
46 #include <stdexcept>
47 #include <iostream>
48 #include <cstdlib>
49 #include <cstdio>
50 #if defined(LIBXSMM_OFFLOAD_TARGET)
51 # pragma offload_attribute(pop)
52 #endif
53 
54 #if !defined(ITYPE)
55 # define ITYPE float
56 #endif
57 
58 #if !defined(CHECK) && (LIBXSMM_EQUAL(ITYPE, float) || LIBXSMM_EQUAL(ITYPE, double))
59 # if !defined(MKL_DIRECT_CALL_SEQ) && !defined(MKL_DIRECT_CALL)
LIBXSMM_BLAS_SYMBOL_DECL(ITYPE,gemm)60 LIBXSMM_BLAS_SYMBOL_DECL(ITYPE, gemm)
61 # endif
62 # define CHECK
63 #endif
64 
65 
66 int main(int argc, char* argv[])
67 {
68   int result = EXIT_SUCCESS;
69   try {
70 #if !defined(__EIGEN_UNSUPPORTED)
71     LIBXSMM_UNUSED(argc); LIBXSMM_UNUSED(argv);
72     throw std::runtime_error("Eigen or Eigen/unsupported not found!");
73 #else
74     LIBXSMM_BLAS_CONST libxsmm_blasint m = (1 < argc ? std::atoi(argv[1]) : 512);
75     LIBXSMM_BLAS_CONST libxsmm_blasint k = (3 < argc ? atoi(argv[3]) : m);
76     LIBXSMM_BLAS_CONST libxsmm_blasint n = (2 < argc ? atoi(argv[2]) : k);
77     const int nrepeat = LIBXSMM_MAX(4 < argc ? atoi(argv[4]) : 13 / LIBXSMM_MAX(1, libxsmm_icbrt_u64(1ULL * m * n * k) >> 10), 3);
78 # if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
79     const double env_check = (0 == getenv("CHECK") ? 1.0 : atof(getenv("CHECK")));
80     const double check = LIBXSMM_ABS(env_check);
81 # endif
82     const double gflops = 2.0 * m * n * k * 1E-9;
83     const int max_nthreads = Eigen::nbThreads();
84     const int env_nthreads = 0 == getenv("NTHREADS") ? max_nthreads : atoi(getenv("NTHREADS"));
85     const int nthreads = LIBXSMM_CLMP(env_nthreads, 1, max_nthreads);
86 # if defined(LIBXSMM_OFFLOAD_TARGET)
87 #   pragma offload target(LIBXSMM_OFFLOAD_TARGET)
88 # endif
89     {
90       Eigen::ThreadPool threadpool(nthreads);
91       Eigen::ThreadPoolDevice device(&threadpool, threadpool.NumThreads());
92       Eigen::Tensor<ITYPE,2/*nindices*/,0/*options*/,libxsmm_blasint> ta(m, k), tb(k, n), tc(m, n);
93       LIBXSMM_BLAS_CONST char transa = 'N', transb = 'N';
94       LIBXSMM_BLAS_CONST ITYPE alpha(1), beta(0);
95       unsigned long long start;
96       double d1;
97       {
98         std::array<Eigen::IndexPair<libxsmm_blasint>,1> product_dims = {
99           Eigen::IndexPair<libxsmm_blasint>(1, 0),
100         };
101         ta.setRandom(); tb.setRandom();
102         start = libxsmm_timer_tick();
103         for (int i = 0; i < nrepeat; ++i) {
104           tc.device(device) = ta.contract(tb, product_dims);
105         }
106         d1 = libxsmm_timer_duration(start, libxsmm_timer_tick());
107       }
108       libxsmm_gemm_print(stdout, libxsmm_gemm_precision_enum<ITYPE>::value, &transa, &transb,
109         &m, &n, &k, &alpha, ta.data(), &m, tb.data(), &k, &beta, tc.data(), &m);
110       fprintf(stdout, "\n\n");
111 # if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
112       Eigen::Tensor<ITYPE, 2/*nindices*/, 0/*options*/, libxsmm_blasint> td(m, n);
113       double d2;
114       {
115         start = libxsmm_timer_tick();
116         for (int i = 0; i < nrepeat; ++i) {
117           LIBXSMM_GEMM_SYMBOL(ITYPE)(&transa, &transb, &m, &n, &k,
118             &alpha, ta.data(), &m, tb.data(), &k,
119              &beta, td.data(), &m);
120         }
121         d2 = libxsmm_timer_duration(start, libxsmm_timer_tick());
122       }
123 # endif
124       if (0 < d1) {
125         fprintf(stdout, "\tEigen"
126 # if !defined(USE_LIBXSMM)
127           "+XSMM"
128 # endif
129           ": %.1f GFLOPS/s\n", gflops * nrepeat / d1);
130       }
131 # if defined(CHECK) && (!defined(__BLAS) || (0 != __BLAS))
132       if (0 < d2) {
133         fprintf(stdout, "\tBLAS: %.1f GFLOPS/s\n", gflops * nrepeat / d2);
134       }
135       libxsmm_matdiff_info diff;
136       result = libxsmm_matdiff(&diff, LIBXSMM_DATATYPE(ITYPE), m, n, td.data(), tc.data(), &m, &m);
137       if (EXIT_SUCCESS == result) {
138         fprintf(stdout, "\tdiff: L2abs=%f Linf=%f\n", diff.l2_abs, diff.linf_abs);
139         if (check < diff.l2_rel) {
140           fprintf(stderr, "FAILED.\n");
141           result = EXIT_FAILURE;
142         }
143       }
144 # endif
145     }
146     fprintf(stdout, "Finished\n");
147 #endif /*defined(__EIGEN_UNSUPPORTED)*/
148   }
149   catch(const std::exception& e) {
150     fprintf(stderr, "Error: %s\n", e.what());
151     result = EXIT_FAILURE;
152   }
153   catch(const char* message) {
154     fprintf(stderr, "Error: %s\n", message);
155     result = EXIT_FAILURE;
156   }
157   catch(...) {
158     fprintf(stderr, "Error: unknown exception caught!\n");
159     result = EXIT_FAILURE;
160   }
161 
162   return result;
163 }
164 
165