/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Hans Pabst (Intel Corp.)
******************************************************************************/
#ifndef LIBXSMM_INTRINSICS_X86_H
#define LIBXSMM_INTRINSICS_X86_H

#include "libxsmm_cpuid.h"

/** https://github.com/intel/Immintrin-debug */
#if !defined(LIBXSMM_INTRINSICS_DEBUG) && 0
# define LIBXSMM_INTRINSICS_DEBUG
#endif
#if defined(LIBXSMM_INTRINSICS_DEBUG)
# include "immintrin_dbg.h"
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
# if !defined(_mm512_undefined_epi32)
#   define _mm512_undefined_epi32() _mm512_set1_epi32(0)
# endif
# if !defined(_mm256_movemask_epi8)
# define _mm256_movemask_epi8 mm256_movemask_epi8_dbg
  LIBXSMM_API_INLINE int mm256_movemask_epi8_dbg(__m256i k) {
    unsigned char mask[32], i; int result = 0;
    _mm256_storeu_si256((__m256i*)mask, k);
    for (i = 0; i < 32; ++i) result |= (mask[i] >> 7) << i;
    return result;
  }
# endif
# if !defined(_mm512_and_epi32)
# define _mm512_and_epi32 mm512_and_epi32_dbg
  LIBXSMM_API_INLINE __m512i mm512_and_epi32_dbg(__m512i a, __m512i b) {
    uint32_t a16[16], b16[16]; signed char i;
    _mm512_storeu_si512((__m512i*)a16, a);
    _mm512_storeu_si512((__m512i*)b16, b);
    for (i = 0; i < 16; ++i) a16[i] &= b16[i];
    return _mm512_loadu_si512((const __m512i*)a16);
  }
# endif
# if !defined(_mm512_or_epi32)
# define _mm512_or_epi32 mm512_or_epi32_dbg
  LIBXSMM_API_INLINE __m512i mm512_or_epi32_dbg(__m512i a, __m512i b) {
    uint32_t a16[16], b16[16]; signed char i;
    _mm512_storeu_si512((__m512i*)a16, a);
    _mm512_storeu_si512((__m512i*)b16, b);
    for (i = 0; i < 16; ++i) a16[i] |= b16[i];
    return _mm512_loadu_si512((const __m512i*)a16);
  }
# endif
# if !defined(_mm512_xor_epi32)
# define _mm512_xor_epi32 mm512_xor_epi32_dbg
  LIBXSMM_API_INLINE __m512i mm512_xor_epi32_dbg(__m512i a, __m512i b) {
    uint32_t a16[16], b16[16]; signed char i;
    _mm512_storeu_si512((__m512i*)a16, a);
    _mm512_storeu_si512((__m512i*)b16, b);
    for (i = 0; i < 16; ++i) a16[i] ^= b16[i];
    return _mm512_loadu_si512((const __m512i*)a16);
  }
# endif
# if !defined(_mm512_srli_epi32_dbg) /* GCC: avoid conflict w/ built-in */
# undef _mm512_srli_epi32
# define _mm512_srli_epi32 mm512_srli_epi32_dbg
  LIBXSMM_API_INLINE __m512i mm512_srli_epi32_dbg(__m512i a, unsigned int imm8) {
    uint32_t a16[16]; signed char i;
    _mm512_storeu_si512((__m512i*)a16, a);
    for (i = 0; i < 16; ++i) a16[i] >>= imm8;
    return _mm512_loadu_si512((const __m512i*)a16);
  }
# endif
# if !defined(_mm512_slli_epi32_dbg) /* GCC: avoid conflict w/ built-in */
# undef _mm512_slli_epi32
# define _mm512_slli_epi32 mm512_slli_epi32_dbg
  LIBXSMM_API_INLINE __m512i mm512_slli_epi32_dbg(__m512i a, unsigned int imm8) {
    uint32_t a16[16]; signed char i;
    _mm512_storeu_si512((__m512i*)a16, a);
    for (i = 0; i < 16; ++i) a16[i] <<= imm8;
    return _mm512_loadu_si512((const __m512i*)a16);
  }
# endif
# if !defined(_mm512_sub_ps)
# define _mm512_sub_ps mm512_sub_ps_dbg
  LIBXSMM_API_INLINE __m512 mm512_sub_ps_dbg(__m512 a, __m512 b) {
    float a16[16], b16[16]; signed char i;
    _mm512_storeu_ps((__m512*)a16, a);
    _mm512_storeu_ps((__m512*)b16, b);
    for (i = 0; i < 16; ++i) a16[i] -= b16[i];
    return _mm512_loadu_ps((const __m512*)a16);
  }
# endif
#endif

/** Macro evaluates to LIBXSMM_ATTRIBUTE_TARGET_xxx (see below). */
#define LIBXSMM_ATTRIBUTE_TARGET(TARGET) LIBXSMM_CONCATENATE(LIBXSMM_ATTRIBUTE_TARGET_, TARGET)

#if /*no intrinsics: tested with 17.x and 18.x*/(defined(__PGI) && \
    LIBXSMM_VERSION2(19, 0) > LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__)) \
 || /*legacy*/(defined(_CRAYC) && !defined(__GNUC__))
# if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC)
#   define LIBXSMM_INTRINSICS_NONE
# endif
#elif !defined(LIBXSMM_INTRINSICS_STATIC) && !defined(LIBXSMM_INTRINSICS_NONE) && ( \
      (defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && \
        LIBXSMM_VERSION2(4, 4) > LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) /* GCC 4.4 (target-attribute) */ \
   || (defined(__clang__) && LIBXSMM_VERSION2(3, 7) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
   || (defined(__APPLE__) && defined(__MACH__) && !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && \
        LIBXSMM_VERSION2(9, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)))
# define LIBXSMM_INTRINSICS_STATIC
#endif

#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
#endif

#if defined(__MIC__) && !defined(LIBXSMM_INTRINSICS_NONE)
# if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#   define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC
# endif
# define LIBXSMM_INTRINSICS(TARGET)
# define LIBXSMM_INTRINSICS_INCLUDE
#elif !defined(LIBXSMM_INTRINSICS_NONE) /*!defined(__MIC__)*/
# if    defined(__AVX512F__)  && defined(__AVX512CD__) \
   &&   defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__) \
   &&   defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
   && (!defined(__GNUC__)  || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) /* TODO: check GCC, Clang, etc. */ \
                           || (LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
   && (!defined(__clang__) || (LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
   && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(99, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif  defined(__AVX512F__)  && defined(__AVX512CD__) \
   &&   defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) \
   &&   defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
   && (!defined(__GNUC__)  || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
                           || (LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
   && (!defined(__clang__) || (LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
   && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif  defined(__AVX512F__)  && defined(__AVX512CD__) \
   &&   defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) \
   &&   defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
   && (!defined(__GNUC__)  || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
                           || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
   && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
   && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif  defined(__AVX512F__) && defined(__AVX512CD__) \
   &&   defined(__AVX512PF__) && defined(__AVX512ER__) \
   &&   defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
   && (!defined(__GNUC__)  || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
                           || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
   && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
   && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif  defined(__AVX512F__) && defined(__AVX512CD__) \
   &&   defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
   && (!defined(__GNUC__)  || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
                           || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
   && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
   && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE4
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(__SSE3__)
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE3
#   endif
#   define LIBXSMM_INTRINSICS_INCLUDE
# elif defined(LIBXSMM_PLATFORM_X86)
#   if !defined(LIBXSMM_STATIC_TARGET_ARCH)
#     define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_GENERIC
#   endif
#   if defined(__GNUC__)
#     define LIBXSMM_INTRINSICS_INCLUDE
#   endif
# endif
# if defined(LIBXSMM_STATIC_TARGET_ARCH) && !defined(LIBXSMM_INTRINSICS_STATIC)
#   if defined(__INTEL_COMPILER)
      /* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */
#     if 1904 <= (LIBXSMM_INTEL_COMPILER) && !defined(_WIN32)
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
#     elif 1801 <= (LIBXSMM_INTEL_COMPILER)
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
#     elif 1500 <= (LIBXSMM_INTEL_COMPILER)
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
#     elif 1400 <= (LIBXSMM_INTEL_COMPILER)
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC
#     else
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
#     endif
#     define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
#     define LIBXSMM_INTRINSICS_INCLUDE
#   elif defined(_CRAYC) && defined(__GNUC__)
      /* TODO: version check, e.g., LIBXSMM_VERSION2(11, 5) <= LIBXSMM_VERSION2(_RELEASE, _RELEASE_MINOR) */
#     define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX
#     define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
#     define LIBXSMM_INTRINSICS_INCLUDE
#   elif defined(_MSC_VER) && !defined(__clang__)
      /* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */
#     define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
#     define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
#     define LIBXSMM_INTRINSICS_INCLUDE
#   elif (!defined(__GNUC__)  || LIBXSMM_VERSION2(4, 9) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
      && (!defined(__clang__) || LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
      && (!defined(__APPLE__) || !defined(__MACH__)) && !defined(__PGI) && !defined(_MSC_VER)
#     if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
#     elif (defined(__clang__) && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
#     elif (defined(__GNUC__)  && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
        || (defined(__clang__) && LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__))
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
#     elif (defined(__GNUC__)  && LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
        || (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
#     elif (defined(__GNUC__)  && LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
        || (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
#     else
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
#     endif
#     define LIBXSMM_INTRINSICS_INCLUDE
#   else /* GCC/legacy incl. Clang */
#     if defined(__clang__) && !(defined(__APPLE__) && defined(__MACH__)) && !defined(_WIN32)
#       if (LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) /* TODO */
          /* no limitations */
#       elif (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
#         if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/)
#           define LIBXSMM_INTRINSICS_STATIC
#         endif
#       elif !defined(LIBXSMM_INTRINSICS_STATIC)
#         define LIBXSMM_INTRINSICS_STATIC
#       endif
#       if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */
#         define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
#       elif LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)
#         define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
#       elif LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__)
#         define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
#       elif LIBXSMM_VERSION2( 6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)
#         define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
#       else
#         define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
#       endif
#     else /* fall-back */
#       define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
#       if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/)
#         define LIBXSMM_INTRINSICS_STATIC
#       endif
#     endif
#     if !defined(LIBXSMM_INTRINSICS_INCLUDE) && (!defined(__PGI) || LIBXSMM_VERSION2(19, 0) <= LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__))
#       define LIBXSMM_INTRINSICS_INCLUDE
#     endif
#   endif /* GCC/legacy incl. Clang */
#   if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
#     error "LIBXSMM_MAX_STATIC_TARGET_ARCH not defined!"
#   endif
#   if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_DEBUG)
#     include <immintrin.h>
#   endif /*defined(LIBXSMM_INTRINSICS_INCLUDE)*/
#   if !defined(LIBXSMM_INTRINSICS)
#     if (LIBXSMM_MAX_STATIC_TARGET_ARCH > LIBXSMM_STATIC_TARGET_ARCH)
#       define LIBXSMM_INTRINSICS(TARGET) LIBXSMM_ATTRIBUTE(LIBXSMM_ATTRIBUTE_TARGET(TARGET))
        /* LIBXSMM_ATTRIBUTE_TARGET_xxx is required to literally match the CPUID (libxsmm_cpuid.h)! */
#       define LIBXSMM_ATTRIBUTE_TARGET_1002 target("sse2") /* LIBXSMM_X86_GENERIC (64-bit ABI) */
#       if (LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1003 target("sse3")
#       else
#         define LIBXSMM_ATTRIBUTE_TARGET_1003 LIBXSMM_ATTRIBUTE_TARGET_1002
#       endif
#       if (LIBXSMM_X86_SSE4 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1004 target("sse4.1,sse4.2")
#       else
#         define LIBXSMM_ATTRIBUTE_TARGET_1004 LIBXSMM_ATTRIBUTE_TARGET_1003
#       endif
#       if (LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1005 target("avx")
#       else
#         define LIBXSMM_ATTRIBUTE_TARGET_1005 LIBXSMM_ATTRIBUTE_TARGET_1004
#       endif
#       if (LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1006 target("avx2,fma")
#       else
#         define LIBXSMM_ATTRIBUTE_TARGET_1006 LIBXSMM_ATTRIBUTE_TARGET_1005
#       endif
#       if (LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1007 target("avx2,fma,avx512f,avx512cd")
#       else
#         define LIBXSMM_ATTRIBUTE_TARGET_1007 LIBXSMM_ATTRIBUTE_TARGET_1006
#       endif
#       if (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1010 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er")
#       else /* LIBXSMM_X86_AVX512 */
#         define LIBXSMM_ATTRIBUTE_TARGET_1010 LIBXSMM_ATTRIBUTE_TARGET_1007
#       endif
#       if (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1011 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er,avx5124vnniw,avx5124fmaps")
#       else /* LIBXSMM_X86_AVX512_MIC */
#         define LIBXSMM_ATTRIBUTE_TARGET_1011 LIBXSMM_ATTRIBUTE_TARGET_1010
#       endif
#       if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1020 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl")
#       else /* LIBXSMM_X86_AVX512 */
#         define LIBXSMM_ATTRIBUTE_TARGET_1020 LIBXSMM_ATTRIBUTE_TARGET_1007
#       endif
#       if (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1021 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni")
#       else /* LIBXSMM_X86_AVX512_CORE */
#         define LIBXSMM_ATTRIBUTE_TARGET_1021 LIBXSMM_ATTRIBUTE_TARGET_1020
#       endif
#       if (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
#         define LIBXSMM_ATTRIBUTE_TARGET_1022 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni,avx512bf16")
#       else /* LIBXSMM_X86_AVX512_CORE */
#         define LIBXSMM_ATTRIBUTE_TARGET_1022 LIBXSMM_ATTRIBUTE_TARGET_1021
#       endif
#     else
#       define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
#     endif
#   elif !defined(LIBXSMM_INTRINSICS_TARGET)
#     define LIBXSMM_INTRINSICS_TARGET
#   endif /*!defined(LIBXSMM_INTRINSICS)*/
# endif /*defined(LIBXSMM_STATIC_TARGET_ARCH)*/
#endif /*!defined(LIBXSMM_INTRINSICS_NONE)*/

#if !defined(LIBXSMM_STATIC_TARGET_ARCH)
# if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC)
#   define LIBXSMM_INTRINSICS_NONE
# endif
# define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC
#endif

#if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
#elif (LIBXSMM_MAX_STATIC_TARGET_ARCH < LIBXSMM_STATIC_TARGET_ARCH)
# undef LIBXSMM_MAX_STATIC_TARGET_ARCH
# define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
#endif

#if !defined(LIBXSMM_INTRINSICS)
# define LIBXSMM_INTRINSICS(TARGET)
#endif

/** Include basic x86 intrinsics such as __rdtsc. */
#if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_DEBUG)
# if defined(_WIN32)
#   include <intrin.h>
# elif defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) || defined(__clang__) || defined(__PGI)
#   include <x86intrin.h>
# elif defined(__GNUC__) && (LIBXSMM_VERSION2(4, 4) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))
#   include <x86intrin.h>
# endif
# include <xmmintrin.h>
# if defined(__SSE3__)
#   include <pmmintrin.h>
# endif
#endif

#if !defined(LIBXSMM_INTRINSICS_NONE)
# if defined(_WIN32)
#   include <malloc.h>
# else
#   include <mm_malloc.h>
# endif
#endif

/**
 * Intrinsic-specific fix-ups
 */
#if defined(__clang__)
# define LIBXSMM_INTRINSICS_LDDQU_SI128(A) _mm_loadu_si128(A)
#else
# define LIBXSMM_INTRINSICS_LDDQU_SI128(A) _mm_lddqu_si128(A)
#endif
#if !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && ( \
      (LIBXSMM_VERSION2(3, 9) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
   || (LIBXSMM_VERSION2(7, 3) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && \
       defined(__APPLE__) && defined(__MACH__)))
/* prototypes with incorrect signature: _mm512_load_ps takes DP*, _mm512_load_pd takes SP* (checked with v3.8.1) */
# define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const double*)(A))
# define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const float*)(A))
/* Clang misses _mm512_stream_p? (checked with v3.8.1). */
# define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_store_si512(A, B)
# define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_storeu_ps(A, B)
# define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_store_pd(A, B)
#else
# define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const float*)(A))
# define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const double*)(A))
# define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_stream_si512((__m512i*)(A), (B))
# define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_stream_ps(A, B)
# define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_stream_pd(A, B)
#endif
#if !defined(LIBXSMM_INTEL_COMPILER) || (defined(__clang__) && ( \
      (LIBXSMM_VERSION2(8, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)))) \
   || (defined(__APPLE__) && defined(__MACH__)) || defined(__GNUC__)
# define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_si256((__m256i*)(A), B)
#else
# define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_epi32(A, B)
#endif
#if defined(LIBXSMM_INTEL_COMPILER)
# if 1600 <= (LIBXSMM_INTEL_COMPILER)
#   define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
                                                        E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
                             _mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
                                                        E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0)
# else
#   define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
                                                        E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
         _mm512_castps_si512(_mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
                                                        E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0))
# endif
#else
# define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
                                                      E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
               _mm512_set_epi32(((E31) << 16) | (E30), ((E29) << 16) | (E28), ((E27) << 16) | (E26), ((E25) << 16) | (E24), \
                                ((E23) << 16) | (E22), ((E21) << 16) | (E20), ((E19) << 16) | (E18), ((E17) << 16) | (E16), \
                                ((E15) << 16) | (E14), ((E13) << 16) | (E12), ((E11) << 16) | (E10),  ((E9) << 16) |  (E8), \
                                 ((E7) << 16) |  (E6),  ((E5) << 16) |  (E4),  ((E3) << 16) |  (E2),  ((E1) << 16) |  (E0))
#endif
#if defined(LIBXSMM_INTEL_COMPILER) \
  || (defined(__GNUC__) && LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
  || (defined(__clang__) && (!defined(__APPLE__) || !defined(__MACH__)) \
      && LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_mask_i32gather_epi32(A, B, C, D, E)
# define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm512_extracti64x4_epi64(A, B)
# define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_abs_ps(A)
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_undefined_epi32()
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_undefined()
# define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_undefined_pd()
#else
# define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_castps_si512(_mm512_mask_i32gather_ps( \
                           _mm512_castsi512_ps(A), B, C, (const float*)(D), E))
# define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castsi512_pd(A), B))
# define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_castsi512_ps(_mm512_and_epi32( \
                           _mm512_castps_si512(A), _mm512_set1_epi32(0x7FFFFFFF)))
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_set1_epi32(0)
# define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_set1_ps(0)
# define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_set1_pd(0)
#endif
#if (defined(LIBXSMM_INTEL_COMPILER) && (1800 <= (LIBXSMM_INTEL_COMPILER))) \
  || (!defined(LIBXSMM_INTEL_COMPILER) && defined(__GNUC__) \
      && LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
  || ((!defined(__APPLE__) || !defined(__MACH__)) && defined(__clang__) \
      && LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
    LIBXSMM_CONCATENATE(_store_mask, NBITS)((LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR), SRC)
# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \
    LIBXSMM_CONCATENATE(_load_mask, NBITS)((/*const*/ LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(_cvtu32_mask, NBITS)((unsigned int)(A))
#elif defined(LIBXSMM_INTEL_COMPILER)
# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
    (*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC))
# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \
    ((LIBXSMM_CONCATENATE(__mmask, NBITS))_mm512_mask2int(*(const __mmask16*)(SRC_PTR)))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_, NBITS)(A)
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_32(A) ((__mmask32)(A))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_16(A) _mm512_int2mask((int)(A))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_8(A) ((__mmask8)(A))
#else
# define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
    (*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC))
# define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) (*(const LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR))
# define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) ((LIBXSMM_CONCATENATE(__mmask, NBITS))(A))
#endif
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK64(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 64)
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK32(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 32)
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK16(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 16)
#define LIBXSMM_INTRINSICS_MM512_STORE_MASK8(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 8)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK64(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 64)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK32(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 32)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK16(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 16)
#define LIBXSMM_INTRINSICS_MM512_LOAD_MASK8(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 8)
#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK32(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 32)
#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 16)
#define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 8)
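
/* Usage sketch (not part of the API): the NBITS-suffixed wrappers above select the
 * matching __mmask type, e.g. round-tripping a 16-bit mask through memory and
 * converting an integer literal. Guarded out since the header only provides macros. */
#if 0
{ unsigned short kbits = 0x00FF; __mmask16 k;
  k = LIBXSMM_INTRINSICS_MM512_LOAD_MASK16(&kbits);    /* memory -> __mmask16 */
  k = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(0x00FFU); /* integer -> __mmask16 */
  LIBXSMM_INTRINSICS_MM512_STORE_MASK16(&kbits, k);    /* __mmask16 -> memory */
}
#endif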

/**
 * Pseudo intrinsics for portability
 */
LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD32_SW(unsigned int n) {
  unsigned int i, r = 0; if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r;
}
LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD64_SW(unsigned long long n) {
  unsigned long long i; unsigned int r = 0; /* 64-bit index: a 32-bit index overflows to zero and loops forever for inputs beyond bit 31 */
  if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r;
}
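
/* Worked examples (sketch): the software bit-scans return the index of the lowest
 * set bit, and 0 for a zero input by convention. */
#if 0
{ int i = LIBXSMM_INTRINSICS_BITSCANFWD32_SW(8);          /* 3: 8 = 0b1000 */
  int j = LIBXSMM_INTRINSICS_BITSCANFWD64_SW(1ULL << 40); /* 40 */
  int z = LIBXSMM_INTRINSICS_BITSCANFWD32_SW(0);          /* 0 by convention */
}
#endif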

/** Binary Logarithm (based on Stackoverflow's NBITSx macro). */
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N) (0 != ((N) & 0x2/*0b10*/) ? 1 : 0)
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N) (0 != ((N) & 0xC/*0b1100*/) ? (2 | LIBXSMM_INTRINSICS_BITSCANBWD_SW02((N) >> 2)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N) (0 != ((N) & 0xF0/*0b11110000*/) ? (4 | LIBXSMM_INTRINSICS_BITSCANBWD_SW04((N) >> 4)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N) (0 != ((N) & 0xFF00) ? (8 | LIBXSMM_INTRINSICS_BITSCANBWD_SW08((N) >> 8)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N) (0 != ((N) & 0xFFFF0000) ? (16 | LIBXSMM_INTRINSICS_BITSCANBWD_SW16((N) >> 16)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD_SW64(N) (0 != ((N) & 0xFFFFFFFF00000000) ? (32 | LIBXSMM_INTRINSICS_BITSCANBWD_SW32((N) >> 32)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD32_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW32((unsigned int)(N))
#define LIBXSMM_INTRINSICS_BITSCANBWD64_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW64((unsigned long long)(N))
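
/* Worked example (sketch): the backward scan yields the index of the highest set
 * bit, i.e. floor(log2(N)) for N > 0; a zero input yields 0. */
#if 0
{ int hi32 = LIBXSMM_INTRINSICS_BITSCANBWD32_SW(12);         /* 3: 12 = 0b1100 */
  int hi64 = LIBXSMM_INTRINSICS_BITSCANBWD64_SW(1ULL << 40); /* 40 */
}
#endif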

#if defined(_WIN32) && !defined(LIBXSMM_INTRINSICS_NONE)
  LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD32(unsigned int n) {
    unsigned long r = 0; _BitScanForward(&r, n); return (0 != n) * r;
  }
  LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD32(unsigned int n) {
    unsigned long r = 0; _BitScanReverse(&r, n); return r;
  }
# if defined(_WIN64)
  LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD64(unsigned long long n) {
    unsigned long r = 0; _BitScanForward64(&r, n); return (0 != n) * r;
  }
  LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD64(unsigned long long n) {
    unsigned long r = 0; _BitScanReverse64(&r, n); return r;
  }
# else
# define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW
# endif
#elif defined(__GNUC__) && !defined(LIBXSMM_INTRINSICS_NONE)
# define LIBXSMM_INTRINSICS_BITSCANFWD32(N) ((0 != (N)) * __builtin_ctz(N))
# define LIBXSMM_INTRINSICS_BITSCANFWD64(N) ((0 != (N)) * __builtin_ctzll(N))
# define LIBXSMM_INTRINSICS_BITSCANBWD32(N) ((0 != (N)) * (31 - __builtin_clz(N)))
# define LIBXSMM_INTRINSICS_BITSCANBWD64(N) ((0 != (N)) * (63 - __builtin_clzll(N)))
#else /* fall-back implementation */
# define LIBXSMM_INTRINSICS_BITSCANFWD32 LIBXSMM_INTRINSICS_BITSCANFWD32_SW
# define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD32 LIBXSMM_INTRINSICS_BITSCANBWD32_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW
#endif

/** LIBXSMM_NBITS determines the minimum number of bits needed to represent N. */
#define LIBXSMM_NBITS(N) (LIBXSMM_INTRINSICS_BITSCANBWD64(N) + LIBXSMM_MIN(1, N))
#define LIBXSMM_ISQRT2(N) ((unsigned int)((1ULL << (LIBXSMM_NBITS(N) >> 1)) /*+ LIBXSMM_MIN(1, N)*/))
/** LIBXSMM_ILOG2 definition matches ceil(log2(N)). */
LIBXSMM_API_INLINE unsigned int LIBXSMM_ILOG2(unsigned long long n) {
  unsigned int result = 0; if (1 < n) {
    const unsigned int m = LIBXSMM_INTRINSICS_BITSCANBWD64(n);
    result = m + ((unsigned int)LIBXSMM_INTRINSICS_BITSCANBWD64(n - 1) == m);
  } return result;
}
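
/* Worked examples (sketch): LIBXSMM_ILOG2 returns ceil(log2(n)) for n > 1 and 0 for
 * n <= 1; the "+1 when n-1 has the same top bit" term rounds up exact non-powers. */
#if 0
{ unsigned int a = LIBXSMM_ILOG2(8); /* 3: exact power of two */
  unsigned int b = LIBXSMM_ILOG2(9); /* 4: rounded up */
  unsigned int c = LIBXSMM_NBITS(5); /* 3: 5 = 0b101 */
}
#endif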

/**
 * Target attribution
 */
#if !defined(LIBXSMM_INTRINSICS_KNC) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(__MIC__)
# define LIBXSMM_INTRINSICS_KNC
#endif
/** LIBXSMM_INTRINSICS_X86 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_X86) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_GENERIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_X86
#endif
/** LIBXSMM_INTRINSICS_SSE3 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_SSE3) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE3 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_SSE3
#endif
/** LIBXSMM_INTRINSICS_SSE4 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_SSE4) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE4 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE4 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_SSE4
#endif
/** LIBXSMM_INTRINSICS_AVX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX
#endif
/** LIBXSMM_INTRINSICS_AVX2 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX2) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX2
#endif
/** LIBXSMM_INTRINSICS_AVX512 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512
#endif
/** LIBXSMM_INTRINSICS_AVX512_MIC is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_MIC) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_MIC
#endif
/** LIBXSMM_INTRINSICS_AVX512_KNM is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_KNM) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_KNM
#endif
/** LIBXSMM_INTRINSICS_AVX512_CORE is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CORE) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_CORE
#endif
/** LIBXSMM_INTRINSICS_AVX512_CLX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CLX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_CLX
#endif
/** LIBXSMM_INTRINSICS_AVX512_CPX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CPX) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(LIBXSMM_X86_AVX512_CPX) && \
    !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_INTRINSICS_AVX512_CPX
#endif

/**
 * Pseudo intrinsics (AVX-512)
 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
# define LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( A, B ) _mm512_cvtepi32_epi16(_mm512_cvt_roundps_epi32( \
    _mm512_mul_ps(LIBXSMM_INTRINSICS_MM512_LOAD_PS(A), B), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(__m512 a) {
  const __m512i vnaninf = _mm512_set1_epi32(0x7f800000), vrneadd = _mm512_set1_epi32(0x00007fff);
  const __m512i vfixup = _mm512_set1_epi32(0x00000001), vfixupmask = _mm512_set1_epi32(0x00010000);
  const __m512i mm512_roundbf16rne_a_ = _mm512_castps_si512(a);
  const __mmask16 mm512_roundbf16rne_mask1_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vnaninf), vnaninf, _MM_CMPINT_NE);
  const __mmask16 mm512_roundbf16rne_mask2_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vfixupmask), vfixupmask, _MM_CMPINT_EQ);
  return _mm512_mask_add_epi32(mm512_roundbf16rne_a_, mm512_roundbf16rne_mask1_, mm512_roundbf16rne_a_, _mm512_mask_add_epi32(vrneadd, mm512_roundbf16rne_mask2_, vrneadd, vfixup));
}
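
/* Scalar sketch of the same fixup (assumes IEEE-754 binary32; the helper name is
 * hypothetical): finite values get 0x7fff plus the bf16 LSB added at bit 0, which
 * rounds half-to-even at bit 16, while NaN/Inf pass through unmodified. */
#if 0
LIBXSMM_API_INLINE unsigned int scalar_roundne_bf16(unsigned int f) { /* hypothetical */
  if ((f & 0x7f800000U) != 0x7f800000U) { /* finite (exponent not all-ones)? */
    f += 0x00007fffU + ((f >> 16) & 1U);  /* round half to even */
  }
  return f; /* the upper 16 bits form the bf16 value */
}
#endif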

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m256i LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(__m512 a) {
  return _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16));
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_CVT2_FP32_BF16(__m512 a, __m512 b) {
  const __m256i aa = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(b), 16));
  const __m256i bb = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16));
  return _mm512_inserti64x4(_mm512_inserti64x4(_mm512_setzero_si512(), aa, 0), bb, 1);
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(__m256i a) {
  return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(a),16));
}
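
/* Round-trip sketch: FP32 -> BF16 (rounded to nearest-even) -> FP32 keeps the sign,
 * exponent, and top 7 explicit mantissa bits; the widening places the bf16 payload
 * into the upper half of each 32-bit lane. */
#if 0
{ __m512 a = _mm512_set1_ps(1.2345678f);
  __m256i bh = LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(a); /* 16x bf16 */
  __m512 r = LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(bh);      /* widened back to fp32 */
}
#endif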

/** SVML-intrinsics */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78( __m512 x ) {
  const  __m512 c0        = _mm512_set1_ps(2027025.0f);
  const  __m512 c1        = _mm512_set1_ps(270270.0f);
  const  __m512 c2        = _mm512_set1_ps(6930.0f);
  const  __m512 c3        = _mm512_set1_ps(36.0f);
  const  __m512 c1_d      = _mm512_set1_ps(945945.0f);
  const  __m512 c2_d      = _mm512_set1_ps(51975.0f);
  const  __m512 c3_d      = _mm512_set1_ps(630.0f);
  const  __m512 hi_bound  = _mm512_set1_ps(4.97f);
  const  __m512 lo_bound  = _mm512_set1_ps(-4.97f);
  const  __m512 ones      = _mm512_set1_ps(1.0f);
  const  __m512 neg_ones  = _mm512_set1_ps(-1.0f);

  const __m512 x2         = _mm512_mul_ps( x, x );
  const __m512 t1_nom     = _mm512_fmadd_ps( c3, x2, c2 );
  const __m512 t2_nom     = _mm512_fmadd_ps( t1_nom, x2, c1 );
  const __m512 t3_nom     = _mm512_fmadd_ps( t2_nom, x2, c0 );
  const __m512 nom        = _mm512_mul_ps( t3_nom, x );
  const __m512 t1_denom   = _mm512_add_ps( x2, c3_d );
  const __m512 t2_denom   = _mm512_fmadd_ps( t1_denom, x2, c2_d );
  const __m512 t3_denom   = _mm512_fmadd_ps( t2_denom, x2, c1_d );
  const __m512 denom      = _mm512_fmadd_ps( t3_denom, x2, c0 );
  const __m512 denom_rcp  = _mm512_rcp14_ps( denom );
  const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ);
  const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ);
  __m512 result           = _mm512_mul_ps( nom, denom_rcp );
  result                  = _mm512_mask_blend_ps(mask_hi, result, ones);
  result                  = _mm512_mask_blend_ps(mask_lo, result, neg_ones);

  return result;
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_32( __m512 x ) {
  const  __m512 c1        = _mm512_set1_ps((float)(1.0/27.0));
  const  __m512 c2        = _mm512_set1_ps((float)(1.0/3));
  const  __m512 hi_bound  = _mm512_set1_ps(3.2f);
  const  __m512 lo_bound  = _mm512_set1_ps(-3.2f);
  const  __m512 ones      = _mm512_set1_ps(1.0f);
  const  __m512 neg_ones  = _mm512_set1_ps(-1.0f);

  const __m512 x2         = _mm512_mul_ps( x, x );
  const __m512 t1_nom     = _mm512_fmadd_ps( x2, c1, ones);
  const __m512 nom        = _mm512_mul_ps( t1_nom, x );
  const __m512 denom      = _mm512_fmadd_ps( x2, c2, ones);
  const __m512 denom_rcp  = _mm512_rcp14_ps( denom );
  const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ);
  const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ);
  __m512 result           = _mm512_mul_ps(nom, denom_rcp);
  result                  = _mm512_mask_blend_ps(mask_hi, result, ones);
  result                  = _mm512_mask_blend_ps(mask_lo, result, neg_ones);

  return result;
}
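
/* Usage sketch: both rational approximations clamp to +/-1 outside their bounds
 * (about +/-4.97 for the degree-7/8 variant, +/-3.2 for the cheaper degree-3/2
 * variant); pick the variant by the accuracy/throughput trade-off needed. */
#if 0
{ __m512 x = _mm512_set1_ps(0.5f);
  __m512 t_hi = LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78(x); /* tighter */
  __m512 t_lo = LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_32(x); /* cheaper */
}
#endif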

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP2( __m512 _x ) {
  const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2));
  const __m512 half       = _mm512_set1_ps(0.5f);
  const __m512 c2         = _mm512_set1_ps(0.240226507f);
  const __m512 c1         = _mm512_set1_ps(0.452920674f);
  const __m512 c0         = _mm512_set1_ps(0.713483036f);
  const __m512 ones       = _mm512_set1_ps(1.0f);
  const __m512 minus_twos = _mm512_set1_ps(-2.0f);

  const __m512 x          = _mm512_fmadd_ps(_x, twice_log2_e, half);
#if 1
  const __m512 y          = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
#else
  const __m512 y          = _mm512_reduce_ps(x, 1);
#endif
  const __m512 t1         = _mm512_fmadd_ps( y, c2, c1);
  const __m512 two_to_y   = _mm512_fmadd_ps( y, t1, c0);
  const __m512 exp        = _mm512_scalef_ps( two_to_y, x );
  const __m512 denom_rcp  = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) );
  __m512 result     = _mm512_fmadd_ps( denom_rcp, minus_twos, ones);

 return result;
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP3( __m512 _x ) {
  const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2));
  const __m512 half       = _mm512_set1_ps(0.5f);
  const __m512 c3         = _mm512_set1_ps(0.05550410866f);
  const __m512 c2         = _mm512_set1_ps(0.15697034396f);
  const __m512 c1         = _mm512_set1_ps(0.49454875509f);
  const __m512 c0         = _mm512_set1_ps(0.70654502287f);
  const __m512 ones       = _mm512_set1_ps(1.0f);
  const __m512 minus_twos = _mm512_set1_ps(-2.0f);

  const __m512 x          = _mm512_fmadd_ps(_x, twice_log2_e, half);
#if 1
  const __m512 y          = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
#else
  const __m512 y          = _mm512_reduce_ps(x, 1);
#endif
  const __m512 t1         = _mm512_fmadd_ps( y, c3, c2);
  const __m512 t2         = _mm512_fmadd_ps( y, t1, c1);
  const __m512 two_to_y   = _mm512_fmadd_ps( y, t2, c0);
  const __m512 exp        = _mm512_scalef_ps( two_to_y, x );
  const __m512 denom_rcp  = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) );
  __m512 result     = _mm512_fmadd_ps( denom_rcp, minus_twos, ones);

  return result;
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( __m512 x ) {
  __m512 result, func_p0, func_p1, func_p2;
  const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 );
  const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF );
  const __m512i lut_low = _mm512_set1_epi32( 246 );
  const __m512i lut_high = _mm512_set1_epi32( 261 );
  const __m512 tanh_p0_2_reg = _mm512_set_ps( 0.40555000f,  0.11892800f, -0.00972979f, -0.02740300f, -0.0169851f, -0.00776152f, -0.00305889f,
                                             -0.00116259f, -0.00041726f, -8.53233e-6f,  1.0000000f,  0.99999800f,  0.99975400f,  0.99268200f,
                                              0.93645300f,  0.73833900f);
  const __m512 tanh_p1_2_reg = _mm512_set_ps( 0.495602f, 0.88152f, 1.125700000f, 1.17021000f, 1.1289000000f, 1.07929000f, 1.0432300f, 1.023010f,
                                              1.011620f, 1.00164f, 1.56828e-14f, 4.49924e-7f, 0.0000646924f, 0.00260405f, 0.0311608f, 0.168736f);
  const __m512 tanh_p2_2_reg = _mm512_set_ps(-0.108238f, -0.2384280f, -0.354418000f, -0.38240300f, -0.34135700f, -0.274509000f, -0.20524900f, -0.1511960f,
                                             -0.107635f, -0.0466868f, -3.60822e-16f, -2.05971e-8f, -4.24538e-6f, -0.000231709f, -0.00386434f, -0.0277702f);

  const __m512i signs   = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask);
  const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter);
  __m512i indices       = _mm512_srli_epi32(abs_arg, 22);
  indices               = _mm512_max_epi32(indices, lut_low);
  indices               = _mm512_min_epi32(indices, lut_high);

  func_p0               = _mm512_permutexvar_ps(indices, tanh_p0_2_reg);
  func_p1               = _mm512_permutexvar_ps(indices, tanh_p1_2_reg);
  func_p2               = _mm512_permutexvar_ps(indices, tanh_p2_2_reg);

  result                = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p2, func_p1);
  result                = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0);
  result                = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs));

  return result;
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX3( __m512 x ) {
  __m512 result, func_p0, func_p1, func_p2, func_p3;
  const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 );
  const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF );
  const __m512i lut_low = _mm512_set1_epi32( 246 );
  const __m512i lut_high = _mm512_set1_epi32( 261 );

  const __m512 tanh_p0_3_reg = _mm512_setr_ps( 0.466283000f,  0.82850600f,  0.97437500f,  0.99882600f,  0.9999860f,  1.0000000f, -1.50006e-08f, -7.98169e-06f,
                                              -4.53753e-05f, -0.00023755f, -0.00125285f, -0.00572314f, -0.0227717f, -0.0629089f, -0.084234300f,  0.071199800f);
  const __m512 tanh_p1_3_reg = _mm512_setr_ps( 0.500617f, 0.124369f, 0.0137214f, 0.000464124f, 4.02465e-06f, 0.00000f, 1.00001f, 1.00028f, 1.00112f, 1.00414f,
                                               1.015570f, 1.050950f, 1.1478500f, 1.310130000f, 1.378950000f, 1.07407f);
  const __m512 tanh_p2_3_reg = _mm512_setr_ps(-0.16133200f, -0.0305526f, -0.00245909f, -6.12647e-05f, -3.76127e-07f,  0.000000f, -0.000245872f, -0.00341151f,
                                              -0.00971505f, -0.0256817f, -0.06869110f, -0.162433000f, -0.346828000f, -0.566516f, -0.640214000f, -0.44011900f);
  const __m512 tanh_p3_3_reg = _mm512_setr_ps( 0.0177393f,  0.00253432f,  0.000147303f,  2.69963e-06f, 1.16764e-08f, 0.0000000f, -0.330125f, -0.3176210f,
                                              -0.3017760f, -0.27358000f, -0.219375000f, -0.136197000f, -0.01868680f, 0.0808901f,  0.107095f,  0.0631459f);

  const __m512i signs   = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask);
  const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter);
  __m512i indices       = _mm512_srli_epi32(abs_arg, 22);
  indices               = _mm512_max_epi32(indices, lut_low);
  indices               = _mm512_min_epi32(indices, lut_high);

  func_p0               = _mm512_permutexvar_ps(indices, tanh_p0_3_reg);
  func_p1               = _mm512_permutexvar_ps(indices, tanh_p1_3_reg);
  func_p2               = _mm512_permutexvar_ps(indices, tanh_p2_3_reg);
  func_p3               = _mm512_permutexvar_ps(indices, tanh_p3_3_reg);

  result                = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p3, func_p2);
  result                = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p1);
  result                = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0);
  result                = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs));

  return result;
}

#if defined(LIBXSMM_INTEL_COMPILER)
# define LIBXSMM_INTRINSICS_MM512_TANH_PS(A) _mm512_tanh_ps(A)
# define LIBXSMM_INTRINSICS_MM512_EXP_PS(A) _mm512_exp_ps(A)
#else
# if !defined(LIBXSMM_NO_LIBM)
#   include <math.h>
# endif
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS(__m512 a) {
  float a16[16]; int i;
  _mm512_storeu_ps(a16, a);
  for (i = 0; i < 16; ++i) a16[i] = LIBXSMM_TANHF(a16[i]);
  return _mm512_loadu_ps(a16);
}
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_EXP_PS(__m512 a) {
  float a16[16]; int i;
  _mm512_storeu_ps(a16, a);
  for (i = 0; i < 16; ++i) a16[i] = LIBXSMM_EXPF(a16[i]);
  return _mm512_loadu_ps(a16);
}
#endif /* SVML */

/** 2048-bit state for xoshiro128+ RNG */
#define LIBXSMM_INTRINSICS_MM512_RNG_STATE(INDEX) (*(__m512i*)LIBXSMM_CONCATENATE(libxsmm_intrinsics_mm512_rng_state, INDEX))
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state0[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state1[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state2[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state3[16]);

# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0
LIBXSMM_PRAGMA_OPTIMIZE_OFF /* avoid ICE in case of symbols (-g) */
# endif
/** Generates random numbers in the interval [0, 1); not thread-safe.
 *  This is based on xoshiro128+ 1.0, e.g. http://prng.di.unimi.it/xoshiro128plus.c */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(void) {
  const __m512i result = _mm512_add_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
  const __m512i s = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), 9);
  __m512i t;
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), LIBXSMM_INTRINSICS_MM512_RNG_STATE(0));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), LIBXSMM_INTRINSICS_MM512_RNG_STATE(1));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), LIBXSMM_INTRINSICS_MM512_RNG_STATE(2));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), s);
  t = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 11);
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_or_epi32(t, _mm512_srli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 32 - 11));
  return result;
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_RNG_PS(void) {
  const __m512i rng_mantissa = _mm512_srli_epi32( LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(), 9 );
  const __m512 one = _mm512_set1_ps(1.0f);
  return _mm512_sub_ps(_mm512_castsi512_ps(_mm512_or_epi32(_mm512_set1_epi32(0x3f800000), rng_mantissa)), one);
}
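
/* Usage sketch: the PS variant ORs 23 random mantissa bits into the bit pattern of
 * 1.0f (yielding a value in [1, 2)) and subtracts 1.0f, giving [0, 1). The global
 * state (libxsmm_intrinsics_mm512_rng_state0..3) must be seeded elsewhere first. */
#if 0
{ __m512 r = LIBXSMM_INTRINSICS_MM512_RNG_PS();                 /* 16 floats in [0, 1) */
  __m512i u = LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(); /* raw 32-bit draws */
}
#endif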

/** Generates random numbers in the interval [0, 1); thread-safe as long as each thread
 *  supplies its own state, which must be managed by the user.
 *  This is based on xoshiro128+ 1.0, e.g. http://prng.di.unimi.it/xoshiro128plus.c */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32( unsigned int* stateptr ) {
  __m512i state_0 = _mm512_loadu_si512( stateptr    );
  __m512i state_1 = _mm512_loadu_si512( stateptr+16 );
  __m512i state_2 = _mm512_loadu_si512( stateptr+32 );
  __m512i state_3 = _mm512_loadu_si512( stateptr+48 );
  const __m512i result = _mm512_add_epi32(state_0, state_3);
  const __m512i s = _mm512_slli_epi32(state_1, 9);
  __m512i t;
  state_2 = _mm512_xor_epi32(state_2, state_0);
  state_3 = _mm512_xor_epi32(state_3, state_1);
  state_1 = _mm512_xor_epi32(state_1, state_2);
  state_0 = _mm512_xor_epi32(state_0, state_3);
  state_2 = _mm512_xor_epi32(state_2, s);
  _mm512_storeu_si512( stateptr   , state_0 );
  _mm512_storeu_si512( stateptr+16, state_1 );
  _mm512_storeu_si512( stateptr+32, state_2 );
  t = _mm512_slli_epi32(state_3, 11);
  state_3 = _mm512_or_epi32(t, _mm512_srli_epi32(state_3, 32 - 11));
  _mm512_storeu_si512( stateptr+48, state_3 );
  return result;
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_RNG_EXTSTATE_PS( unsigned int* stateptr) {
  const __m512i rng_mantissa = _mm512_srli_epi32( LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32( stateptr ), 9 );
  const __m512 one = _mm512_set1_ps(1.0f);
  return _mm512_sub_ps(_mm512_castsi512_ps(_mm512_or_epi32(_mm512_set1_epi32(0x3f800000), rng_mantissa)), one);
}
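
/* Usage sketch: the caller owns 64 unsigned ints (4 x 512 bits) of state; the
 * nonzero fill below is purely illustrative (a real seed procedure should follow
 * the xoshiro128+ recommendations referenced above). */
#if 0
{ unsigned int state[64]; int i;
  for (i = 0; i < 64; ++i) state[i] = 0x9e3779b9U * (i + 1); /* illustrative seed */
  { __m512 r = LIBXSMM_INTRINSICS_MM512_RNG_EXTSTATE_PS(state); }
}
#endif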
# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0
LIBXSMM_PRAGMA_OPTIMIZE_ON
# endif
#endif /*__AVX512F__*/

#if defined(LIBXSMM_OFFLOAD_TARGET)
# pragma offload_attribute(pop)
#endif

#endif /*LIBXSMM_INTRINSICS_X86_H*/