1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #ifndef LIBXSMM_INTRINSICS_X86_H
12 #define LIBXSMM_INTRINSICS_X86_H
13
14 #include "libxsmm_cpuid.h"
15
16 /** https://github.com/intel/Immintrin-debug */
17 #if !defined(LIBXSMM_INTRINSICS_DEBUG) && 0
18 # define LIBXSMM_INTRINSICS_DEBUG
19 #endif
20 #if defined(LIBXSMM_INTRINSICS_DEBUG)
21 # include "immintrin_dbg.h"
22 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
23 # if !defined(_mm512_undefined_epi32)
24 # define _mm512_undefined_epi32() _mm512_set1_epi32(0)
25 # endif
26 # if !defined(_mm256_movemask_epi8)
27 # define _mm256_movemask_epi8 mm256_movemask_epi8_dbg
mm256_movemask_epi8_dbg(__m256i k)28 LIBXSMM_API_INLINE int mm256_movemask_epi8_dbg(__m256i k) {
29 unsigned char mask[32], i; int result = 0;
30 _mm256_storeu_si256((__m256i*)mask, k);
31 for (i = 0; i < 32; ++i) result |= (mask[i] >> 7) << i;
32 return result;
33 }
34 # endif
35 # if !defined(_mm512_and_epi32)
36 # define _mm512_and_epi32 mm512_and_epi32_dbg
mm512_and_epi32_dbg(__m512i a,__m512i b)37 LIBXSMM_API_INLINE __m512i mm512_and_epi32_dbg(__m512i a, __m512i b) {
38 uint32_t a16[16], b16[16]; signed char i;
39 _mm512_storeu_si512((__m512i*)a16, a);
40 _mm512_storeu_si512((__m512i*)b16, b);
41 for (i = 0; i < 16; ++i) a16[i] &= b16[i];
42 return _mm512_loadu_si512((const __m512i*)a16);
43 }
44 # endif
45 # if !defined(_mm512_or_epi32)
46 # define _mm512_or_epi32 mm512_or_epi32_dbg
mm512_or_epi32_dbg(__m512i a,__m512i b)47 LIBXSMM_API_INLINE __m512i mm512_or_epi32_dbg(__m512i a, __m512i b) {
48 uint32_t a16[16], b16[16]; signed char i;
49 _mm512_storeu_si512((__m512i*)a16, a);
50 _mm512_storeu_si512((__m512i*)b16, b);
51 for (i = 0; i < 16; ++i) a16[i] |= b16[i];
52 return _mm512_loadu_si512((const __m512i*)a16);
53 }
54 # endif
55 # if !defined(_mm512_xor_epi32)
56 # define _mm512_xor_epi32 mm512_xor_epi32_dbg
mm512_xor_epi32_dbg(__m512i a,__m512i b)57 LIBXSMM_API_INLINE __m512i mm512_xor_epi32_dbg(__m512i a, __m512i b) {
58 uint32_t a16[16], b16[16]; signed char i;
59 _mm512_storeu_si512((__m512i*)a16, a);
60 _mm512_storeu_si512((__m512i*)b16, b);
61 for (i = 0; i < 16; ++i) a16[i] ^= b16[i];
62 return _mm512_loadu_si512((const __m512i*)a16);
63 }
64 # endif
65 # if !defined(_mm512_srli_epi32_dbg) /* GCC: avoid conflict w/ built-in */
66 # undef _mm512_srli_epi32
67 # define _mm512_srli_epi32 mm512_srli_epi32_dbg
mm512_srli_epi32_dbg(__m512i a,unsigned int imm8)68 LIBXSMM_API_INLINE __m512i mm512_srli_epi32_dbg(__m512i a, unsigned int imm8) {
69 uint32_t a16[16]; signed char i;
70 _mm512_storeu_si512((__m512i*)a16, a);
71 for (i = 0; i < 16; ++i) a16[i] >>= imm8;
72 return _mm512_loadu_si512((const __m512i*)a16);
73 }
74 # endif
75 # if !defined(_mm512_slli_epi32_dbg) /* GCC: avoid conflict w/ built-in */
76 # undef _mm512_slli_epi32
77 # define _mm512_slli_epi32 mm512_slli_epi32_dbg
mm512_slli_epi32_dbg(__m512i a,unsigned int imm8)78 LIBXSMM_API_INLINE __m512i mm512_slli_epi32_dbg(__m512i a, unsigned int imm8) {
79 uint32_t a16[16]; signed char i;
80 _mm512_storeu_si512((__m512i*)a16, a);
81 for (i = 0; i < 16; ++i) a16[i] <<= imm8;
82 return _mm512_loadu_si512((const __m512i*)a16);
83 }
84 # endif
85 # if !defined(_mm512_sub_ps)
86 # define _mm512_sub_ps mm512_sub_ps_dbg
mm512_sub_ps_dbg(__m512 a,__m512 b)87 LIBXSMM_API_INLINE __m512 mm512_sub_ps_dbg(__m512 a, __m512 b) {
88 float a16[16], b16[16]; signed char i;
89 _mm512_storeu_ps((__m512*)a16, a);
90 _mm512_storeu_ps((__m512*)b16, b);
91 for (i = 0; i < 16; ++i) a16[i] -= b16[i];
92 return _mm512_loadu_ps((const __m512*)a16);
93 }
94 # endif
95 #endif
96
97 /** Macro evaluates to LIBXSMM_ATTRIBUTE_TARGET_xxx (see below). */
98 #define LIBXSMM_ATTRIBUTE_TARGET(TARGET) LIBXSMM_CONCATENATE(LIBXSMM_ATTRIBUTE_TARGET_, TARGET)
99
100 #if /*no intrinsics: tested with 17.x and 18.x*/(defined(__PGI) && \
101 LIBXSMM_VERSION2(19, 0) > LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__)) \
102 || /*legacy*/(defined(_CRAYC) && !defined(__GNUC__))
103 # if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC)
104 # define LIBXSMM_INTRINSICS_NONE
105 # endif
106 #elif !defined(LIBXSMM_INTRINSICS_STATIC) && !defined(LIBXSMM_INTRINSICS_NONE) && ( \
107 (defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && \
108 LIBXSMM_VERSION2(4, 4) > LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) /* GCC 4.4 (target-attribute) */ \
109 || (defined(__clang__) && LIBXSMM_VERSION2(3, 7) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
110 || (defined(__APPLE__) && defined(__MACH__) && !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && \
111 LIBXSMM_VERSION2(9, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)))
112 # define LIBXSMM_INTRINSICS_STATIC
113 #endif
114
115 #if defined(LIBXSMM_OFFLOAD_TARGET)
116 # pragma offload_attribute(push,target(LIBXSMM_OFFLOAD_TARGET))
117 #endif
118
119 #if defined(__MIC__) && !defined(LIBXSMM_INTRINSICS_NONE)
120 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
121 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC
122 # endif
123 # define LIBXSMM_INTRINSICS(TARGET)
124 # define LIBXSMM_INTRINSICS_INCLUDE
125 #elif !defined(LIBXSMM_INTRINSICS_NONE) /*!defined(__MIC__)*/
126 # if defined(__AVX512F__) && defined(__AVX512CD__) \
127 && defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) && defined(__AVX512BF16__) \
128 && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
129 && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) /* TODO: check GCC, Clang, etc. */ \
130 || (LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
131 && (!defined(__clang__) || (LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
132 && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(99, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
133 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
134 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
135 # endif
136 # define LIBXSMM_INTRINSICS_INCLUDE
137 # elif defined(__AVX512F__) && defined(__AVX512CD__) \
138 && defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) && defined(__AVX512VNNI__) \
139 && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
140 && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
141 || (LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
142 && (!defined(__clang__) || (LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
143 && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
144 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
145 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
146 # endif
147 # define LIBXSMM_INTRINSICS_INCLUDE
148 # elif defined(__AVX512F__) && defined(__AVX512CD__) \
149 && defined(__AVX512DQ__) && defined(__AVX512BW__) && defined(__AVX512VL__) \
150 && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
151 && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
152 || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
153 && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
154 && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
155 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
156 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
157 # endif
158 # define LIBXSMM_INTRINSICS_INCLUDE
159 # elif defined(__AVX512F__) && defined(__AVX512CD__) \
160 && defined(__AVX512PF__) && defined(__AVX512ER__) \
161 && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
162 && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
163 || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
164 && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
165 && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
166 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
167 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC
168 # endif
169 # define LIBXSMM_INTRINSICS_INCLUDE
170 # elif defined(__AVX512F__) && defined(__AVX512CD__) \
171 && defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__) \
172 && (!defined(__GNUC__) || defined(__clang__) || defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) \
173 || (LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))) \
174 && (!defined(__clang__) || (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))) \
175 && (!defined(__APPLE__) || !defined(__MACH__) || LIBXSMM_VERSION2(9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
176 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
177 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512
178 # endif
179 # define LIBXSMM_INTRINSICS_INCLUDE
180 # elif defined(__AVX2__) && defined(__FMA__) && defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
181 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
182 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
183 # endif
184 # define LIBXSMM_INTRINSICS_INCLUDE
185 # elif defined(__AVX__) && defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
186 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
187 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_AVX
188 # endif
189 # define LIBXSMM_INTRINSICS_INCLUDE
190 # elif defined(__SSE4_2__) && defined(__SSE4_1__) && defined(__SSE3__)
191 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
192 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE4
193 # endif
194 # define LIBXSMM_INTRINSICS_INCLUDE
195 # elif defined(__SSE3__)
196 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
197 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_SSE3
198 # endif
199 # define LIBXSMM_INTRINSICS_INCLUDE
200 # elif defined(LIBXSMM_PLATFORM_X86)
201 # if !defined(LIBXSMM_STATIC_TARGET_ARCH)
202 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_X86_GENERIC
203 # endif
204 # if defined(__GNUC__)
205 # define LIBXSMM_INTRINSICS_INCLUDE
206 # endif
207 # endif
208 # if defined(LIBXSMM_STATIC_TARGET_ARCH) && !defined(LIBXSMM_INTRINSICS_STATIC)
209 # if defined(__INTEL_COMPILER)
210 /* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */
211 # if 1904 <= (LIBXSMM_INTEL_COMPILER) && !defined(_WIN32)
212 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
213 # elif 1801 <= (LIBXSMM_INTEL_COMPILER)
214 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
215 # elif 1500 <= (LIBXSMM_INTEL_COMPILER)
216 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
217 # elif 1400 <= (LIBXSMM_INTEL_COMPILER)
218 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_MIC
219 # else
220 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
221 # endif
222 # define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
223 # define LIBXSMM_INTRINSICS_INCLUDE
224 # elif defined(_CRAYC) && defined(__GNUC__)
225 /* TODO: version check, e.g., LIBXSMM_VERSION2(11, 5) <= LIBXSMM_VERSION2(_RELEASE, _RELEASE_MINOR) */
226 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX
227 # define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
228 # define LIBXSMM_INTRINSICS_INCLUDE
229 # elif defined(_MSC_VER) && !defined(__clang__)
230 /* TODO: compiler version check for LIBXSMM_MAX_STATIC_TARGET_ARCH */
231 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
232 # define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
233 # define LIBXSMM_INTRINSICS_INCLUDE
234 # elif (!defined(__GNUC__) || LIBXSMM_VERSION2(4, 9) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
235 && (!defined(__clang__) || LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
236 && (!defined(__APPLE__) || !defined(__MACH__)) && !defined(__PGI) && !defined(_MSC_VER)
237 # if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */
238 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
239 # elif (defined(__clang__) && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
240 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
241 # elif (defined(__GNUC__) && LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
242 || (defined(__clang__) && LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__))
243 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
244 # elif (defined(__GNUC__) && LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
245 || (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
246 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
247 # elif (defined(__GNUC__) && LIBXSMM_VERSION2(5, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
248 || (defined(__clang__) && LIBXSMM_VERSION2(6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
249 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
250 # else
251 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
252 # endif
253 # define LIBXSMM_INTRINSICS_INCLUDE
254 # else /* GCC/legacy incl. Clang */
255 # if defined(__clang__) && !(defined(__APPLE__) && defined(__MACH__)) && !defined(_WIN32)
256 # if (LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) /* TODO */
257 /* no limitations */
258 # elif (LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
259 # if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/)
260 # define LIBXSMM_INTRINSICS_STATIC
261 # endif
262 # elif !defined(LIBXSMM_INTRINSICS_STATIC)
263 # define LIBXSMM_INTRINSICS_STATIC
264 # endif
265 # if defined(__CYGWIN__) && !defined(LIBXSMM_INTRINSICS_DEBUG) /* Cygwin: invalid register for .seh_savexmm */
266 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX2
267 # elif LIBXSMM_VERSION2(10, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)
268 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
269 # elif LIBXSMM_VERSION2( 9, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && !defined(__cray__)
270 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CPX
271 # elif LIBXSMM_VERSION2( 6, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__)
272 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CLX
273 # else
274 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_X86_AVX512_CORE
275 # endif
276 # else /* fall-back */
277 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
278 # if !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_STATIC_TARGET_ARCH < LIBXSMM_X86_AVX2/*workaround*/)
279 # define LIBXSMM_INTRINSICS_STATIC
280 # endif
281 # endif
282 # if !defined(LIBXSMM_INTRINSICS_INCLUDE) && (!defined(__PGI) || LIBXSMM_VERSION2(19, 0) <= LIBXSMM_VERSION2(__PGIC__, __PGIC_MINOR__))
283 # define LIBXSMM_INTRINSICS_INCLUDE
284 # endif
285 # endif /* GCC/legacy incl. Clang */
286 # if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
287 # error "LIBXSMM_MAX_STATIC_TARGET_ARCH not defined!"
288 # endif
289 # if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_DEBUG)
290 # include <immintrin.h>
291 # endif /*defined(LIBXSMM_INTRINSICS_INCLUDE)*/
292 # if !defined(LIBXSMM_INTRINSICS)
293 # if (LIBXSMM_MAX_STATIC_TARGET_ARCH > LIBXSMM_STATIC_TARGET_ARCH)
294 # define LIBXSMM_INTRINSICS(TARGET) LIBXSMM_ATTRIBUTE(LIBXSMM_ATTRIBUTE_TARGET(TARGET))
295 /* LIBXSMM_ATTRIBUTE_TARGET_xxx is required to literally match the CPUID (libxsmm_cpuid.h)! */
296 # define LIBXSMM_ATTRIBUTE_TARGET_1002 target("sse2") /* LIBXSMM_X86_GENERIC (64-bit ABI) */
297 # if (LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
298 # define LIBXSMM_ATTRIBUTE_TARGET_1003 target("sse3")
299 # else
300 # define LIBXSMM_ATTRIBUTE_TARGET_1003 LIBXSMM_ATTRIBUTE_TARGET_1002
301 # endif
302 # if (LIBXSMM_X86_SSE4 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
303 # define LIBXSMM_ATTRIBUTE_TARGET_1004 target("sse4.1,sse4.2")
304 # else
305 # define LIBXSMM_ATTRIBUTE_TARGET_1004 LIBXSMM_ATTRIBUTE_TARGET_1003
306 # endif
307 # if (LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
308 # define LIBXSMM_ATTRIBUTE_TARGET_1005 target("avx")
309 # else
310 # define LIBXSMM_ATTRIBUTE_TARGET_1005 LIBXSMM_ATTRIBUTE_TARGET_1004
311 # endif
312 # if (LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
313 # define LIBXSMM_ATTRIBUTE_TARGET_1006 target("avx2,fma")
314 # else
315 # define LIBXSMM_ATTRIBUTE_TARGET_1006 LIBXSMM_ATTRIBUTE_TARGET_1005
316 # endif
317 # if (LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
318 # define LIBXSMM_ATTRIBUTE_TARGET_1007 target("avx2,fma,avx512f,avx512cd")
319 # else
320 # define LIBXSMM_ATTRIBUTE_TARGET_1007 LIBXSMM_ATTRIBUTE_TARGET_1006
321 # endif
322 # if (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
323 # define LIBXSMM_ATTRIBUTE_TARGET_1010 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er")
324 # else /* LIBXSMM_X86_AVX512 */
325 # define LIBXSMM_ATTRIBUTE_TARGET_1010 LIBXSMM_ATTRIBUTE_TARGET_1007
326 # endif
327 # if (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
328 # define LIBXSMM_ATTRIBUTE_TARGET_1011 target("avx2,fma,avx512f,avx512cd,avx512pf,avx512er,avx5124vnniw,avx5124fmaps")
329 # else /* LIBXSMM_X86_AVX512_MIC */
330 # define LIBXSMM_ATTRIBUTE_TARGET_1011 LIBXSMM_ATTRIBUTE_TARGET_1010
331 # endif
332 # if (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
333 # define LIBXSMM_ATTRIBUTE_TARGET_1020 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl")
334 # else /* LIBXSMM_X86_AVX512 */
335 # define LIBXSMM_ATTRIBUTE_TARGET_1020 LIBXSMM_ATTRIBUTE_TARGET_1007
336 # endif
337 # if (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
338 # define LIBXSMM_ATTRIBUTE_TARGET_1021 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni")
339 # else /* LIBXSMM_X86_AVX512_CORE */
340 # define LIBXSMM_ATTRIBUTE_TARGET_1021 LIBXSMM_ATTRIBUTE_TARGET_1020
341 # endif
342 # if (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
343 # define LIBXSMM_ATTRIBUTE_TARGET_1022 target("avx2,fma,avx512f,avx512cd,avx512dq,avx512bw,avx512vl,avx512vnni,avx512bf16")
344 # else /* LIBXSMM_X86_AVX512_CORE */
345 # define LIBXSMM_ATTRIBUTE_TARGET_1022 LIBXSMM_ATTRIBUTE_TARGET_1021
346 # endif
347 # else
348 # define LIBXSMM_INTRINSICS(TARGET)/*no need for target flags*/
349 # endif
350 # elif !defined(LIBXSMM_INTRINSICS_TARGET)
351 # define LIBXSMM_INTRINSICS_TARGET
352 # endif /*!defined(LIBXSMM_INTRINSICS)*/
353 # endif /*defined(LIBXSMM_STATIC_TARGET_ARCH)*/
354 #endif /*!defined(LIBXSMM_INTRINSICS_NONE)*/
355
356 #if !defined(LIBXSMM_STATIC_TARGET_ARCH)
357 # if !defined(LIBXSMM_INTRINSICS_NONE) && !defined(LIBXSMM_INTRINSICS_STATIC)
358 # define LIBXSMM_INTRINSICS_NONE
359 # endif
360 # define LIBXSMM_STATIC_TARGET_ARCH LIBXSMM_TARGET_ARCH_GENERIC
361 #endif
362
363 #if !defined(LIBXSMM_MAX_STATIC_TARGET_ARCH)
364 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
365 #elif (LIBXSMM_MAX_STATIC_TARGET_ARCH < LIBXSMM_STATIC_TARGET_ARCH)
366 # undef LIBXSMM_MAX_STATIC_TARGET_ARCH
367 # define LIBXSMM_MAX_STATIC_TARGET_ARCH LIBXSMM_STATIC_TARGET_ARCH
368 #endif
369
370 #if !defined(LIBXSMM_INTRINSICS)
371 # define LIBXSMM_INTRINSICS(TARGET)
372 #endif
373
374 /** Include basic x86 intrinsics such as __rdtsc. */
375 #if defined(LIBXSMM_INTRINSICS_INCLUDE) && !defined(LIBXSMM_INTRINSICS_DEBUG)
376 # if defined(_WIN32)
377 # include <intrin.h>
378 # elif defined(LIBXSMM_INTEL_COMPILER) || defined(_CRAYC) || defined(__clang__) || defined(__PGI)
379 # include <x86intrin.h>
380 # elif defined(__GNUC__) && (LIBXSMM_VERSION2(4, 4) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__))
381 # include <x86intrin.h>
382 # endif
383 # include <xmmintrin.h>
384 # if defined(__SSE3__)
385 # include <pmmintrin.h>
386 # endif
387 #endif
388
389 #if !defined(LIBXSMM_INTRINSICS_NONE)
390 # if defined(_WIN32)
391 # include <malloc.h>
392 # else
393 # include <mm_malloc.h>
394 # endif
395 #endif
396
397 /**
398 * Intrinsic-specific fix-ups
399 */
400 #if defined(__clang__)
401 # define LIBXSMM_INTRINSICS_LDDQU_SI128(A) _mm_loadu_si128(A)
402 #else
403 # define LIBXSMM_INTRINSICS_LDDQU_SI128(A) _mm_lddqu_si128(A)
404 #endif
405 #if !defined(LIBXSMM_INTEL_COMPILER) && defined(__clang__) && ( \
406 (LIBXSMM_VERSION2(3, 9) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)) \
407 || (LIBXSMM_VERSION2(7, 3) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__) && \
408 defined(__APPLE__) && defined(__MACH__)))
409 /* prototypes with incorrect signature: _mm512_load_ps takes DP*, _mm512_load_pd takes SP* (checked with v3.8.1) */
410 # define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const double*)(A))
411 # define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const float*)(A))
412 /* Clang misses _mm512_stream_p? (checked with v3.8.1). */
413 # define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_store_si512(A, B)
414 # define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_storeu_ps(A, B)
415 # define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_store_pd(A, B)
416 #else
417 # define LIBXSMM_INTRINSICS_MM512_LOAD_PS(A) _mm512_loadu_ps((const float*)(A))
418 # define LIBXSMM_INTRINSICS_MM512_LOAD_PD(A) _mm512_loadu_pd((const double*)(A))
419 # define LIBXSMM_INTRINSICS_MM512_STREAM_SI512(A, B) _mm512_stream_si512((__m512i*)(A), (B))
420 # define LIBXSMM_INTRINSICS_MM512_STREAM_PS(A, B) _mm512_stream_ps(A, B)
421 # define LIBXSMM_INTRINSICS_MM512_STREAM_PD(A, B) _mm512_stream_pd(A, B)
422 #endif
423 #if !defined(LIBXSMM_INTEL_COMPILER) || (defined(__clang__) && ( \
424 (LIBXSMM_VERSION2(8, 0) > LIBXSMM_VERSION2(__clang_major__, __clang_minor__)))) \
425 || (defined(__APPLE__) && defined(__MACH__)) || defined(__GNUC__)
426 # define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_si256((__m256i*)(A), B)
427 #else
428 # define LIBXSMM_INTRINSICS_MM256_STORE_EPI32(A, B) _mm256_storeu_epi32(A, B)
429 #endif
430 #if defined(LIBXSMM_INTEL_COMPILER)
431 # if 1600 <= (LIBXSMM_INTEL_COMPILER)
432 # define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
433 E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
434 _mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
435 E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0)
436 # else
437 # define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
438 E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
439 _mm512_castps_si512(_mm512_set_epi16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
440 E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0))
441 # endif
442 #else
443 # define LIBXSMM_INTRINSICS_MM512_SET_EPI16(E31, E30, E29, E28, E27, E26, E25, E24, E23, E22, E21, E20, E19, E18, E17, E16, \
444 E15, E14, E13, E12, E11, E10, E9, E8, E7, E6, E5, E4, E3, E2, E1, E0) \
445 _mm512_set_epi32(((E31) << 16) | (E30), ((E29) << 16) | (E28), ((E27) << 16) | (E26), ((E25) << 16) | (E24), \
446 ((E23) << 16) | (E22), ((E21) << 16) | (E20), ((E19) << 16) | (E18), ((E17) << 16) | (E16), \
447 ((E15) << 16) | (E14), ((E13) << 16) | (E12), ((E11) << 16) | (E10), ((E9) << 16) | (E8), \
448 ((E7) << 16) | (E6), ((E5) << 16) | (E4), ((E3) << 16) | (E2), ((E1) << 16) | (E0))
449 #endif
450 #if defined(LIBXSMM_INTEL_COMPILER) \
451 || (defined(__GNUC__) && LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
452 || (defined(__clang__) && (!defined(__APPLE__) || !defined(__MACH__)) \
453 && LIBXSMM_VERSION2(4, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
454 # define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_mask_i32gather_epi32(A, B, C, D, E)
455 # define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm512_extracti64x4_epi64(A, B)
456 # define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_abs_ps(A)
457 # define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_undefined_epi32()
458 # define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_undefined()
459 # define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_undefined_pd()
460 #else
461 # define LIBXSMM_INTRINSICS_MM512_MASK_I32GATHER_EPI32(A, B, C, D, E) _mm512_castps_si512(_mm512_mask_i32gather_ps( \
462 _mm512_castsi512_ps(A), B, C, (const float*)(D), E))
463 # define LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(A, B) _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castsi512_pd(A), B))
464 # define LIBXSMM_INTRINSICS_MM512_ABS_PS(A) _mm512_castsi512_ps(_mm512_and_epi32( \
465 _mm512_castps_si512(A), _mm512_set1_epi32(0x7FFFFFFF)))
466 # define LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32() _mm512_set1_epi32(0)
467 # define LIBXSMM_INTRINSICS_MM512_UNDEFINED() _mm512_set1_ps(0)
468 # define LIBXSMM_INTRINSICS_MM_UNDEFINED_PD() _mm_set1_pd(0)
469 #endif
470 #if (defined(LIBXSMM_INTEL_COMPILER) && (1800 <= (LIBXSMM_INTEL_COMPILER))) \
471 || (!defined(LIBXSMM_INTEL_COMPILER) && defined(__GNUC__) \
472 && LIBXSMM_VERSION2(7, 0) <= LIBXSMM_VERSION2(__GNUC__, __GNUC_MINOR__)) \
473 || ((!defined(__APPLE__) || !defined(__MACH__)) && defined(__clang__) \
474 && LIBXSMM_VERSION2(8, 0) <= LIBXSMM_VERSION2(__clang_major__, __clang_minor__))
475 # define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
476 LIBXSMM_CONCATENATE(_store_mask, NBITS)((LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR), SRC)
477 # define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \
478 LIBXSMM_CONCATENATE(_load_mask, NBITS)((/*const*/ LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR))
479 # define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(_cvtu32_mask, NBITS)((unsigned int)(A))
480 #elif defined(LIBXSMM_INTEL_COMPILER)
481 # define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
482 (*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC))
483 # define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) \
484 ((LIBXSMM_CONCATENATE(__mmask, NBITS))_mm512_mask2int(*(const __mmask16*)(SRC_PTR)))
485 # define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) LIBXSMM_CONCATENATE(LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_, NBITS)(A)
486 # define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_32(A) ((__mmask32)(A))
487 # define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_16(A) _mm512_int2mask((int)(A))
488 # define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK_8(A) ((__mmask8)(A))
489 #else
490 # define LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, NBITS) \
491 (*(LIBXSMM_CONCATENATE(__mmask, NBITS)*)(DST_PTR) = (LIBXSMM_CONCATENATE(__mmask, NBITS))(SRC))
492 # define LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, NBITS) (*(const LIBXSMM_CONCATENATE(__mmask, NBITS)*)(SRC_PTR))
493 # define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, NBITS) ((LIBXSMM_CONCATENATE(__mmask, NBITS))(A))
494 #endif
495 #define LIBXSMM_INTRINSICS_MM512_STORE_MASK64(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 64)
496 #define LIBXSMM_INTRINSICS_MM512_STORE_MASK32(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 32)
497 #define LIBXSMM_INTRINSICS_MM512_STORE_MASK16(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 16)
498 #define LIBXSMM_INTRINSICS_MM512_STORE_MASK8(DST_PTR, SRC) LIBXSMM_INTRINSICS_MM512_STORE_MASK(DST_PTR, SRC, 8)
499 #define LIBXSMM_INTRINSICS_MM512_LOAD_MASK64(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 64)
500 #define LIBXSMM_INTRINSICS_MM512_LOAD_MASK32(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 32)
501 #define LIBXSMM_INTRINSICS_MM512_LOAD_MASK16(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 16)
502 #define LIBXSMM_INTRINSICS_MM512_LOAD_MASK8(SRC_PTR) LIBXSMM_INTRINSICS_MM512_LOAD_MASK(SRC_PTR, 8)
503 #define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK32(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 32)
504 #define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 16)
505 #define LIBXSMM_INTRINSICS_MM512_CVTU32_MASK8(A) LIBXSMM_INTRINSICS_MM512_CVTU32_MASK(A, 8)
506
507 /**
508 * Pseudo intrinsics for portability
509 */
LIBXSMM_INTRINSICS_BITSCANFWD32_SW(unsigned int n)510 LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD32_SW(unsigned int n) {
511 unsigned int i, r = 0; if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r;
512 }
LIBXSMM_INTRINSICS_BITSCANFWD64_SW(unsigned long long n)513 LIBXSMM_API_INLINE int LIBXSMM_INTRINSICS_BITSCANFWD64_SW(unsigned long long n) {
514 unsigned int i, r = 0; if (0 != n) for (i = 1; 0 == (n & i); i <<= 1) { ++r; } return r;
515 }
516
517 /** Binary Logarithm (based on Stackoverflow's NBITSx macro). */
518 #define LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N) (0 != ((N) & 0x2/*0b10*/) ? 1 : 0)
519 #define LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N) (0 != ((N) & 0xC/*0b1100*/) ? (2 | LIBXSMM_INTRINSICS_BITSCANBWD_SW02((N) >> 2)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW02(N))
520 #define LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N) (0 != ((N) & 0xF0/*0b11110000*/) ? (4 | LIBXSMM_INTRINSICS_BITSCANBWD_SW04((N) >> 4)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW04(N))
521 #define LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N) (0 != ((N) & 0xFF00) ? (8 | LIBXSMM_INTRINSICS_BITSCANBWD_SW08((N) >> 8)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW08(N))
522 #define LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N) (0 != ((N) & 0xFFFF0000) ? (16 | LIBXSMM_INTRINSICS_BITSCANBWD_SW16((N) >> 16)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW16(N))
523 #define LIBXSMM_INTRINSICS_BITSCANBWD_SW64(N) (0 != ((N) & 0xFFFFFFFF00000000) ? (32 | LIBXSMM_INTRINSICS_BITSCANBWD_SW32((N) >> 32)) : LIBXSMM_INTRINSICS_BITSCANBWD_SW32(N))
524 #define LIBXSMM_INTRINSICS_BITSCANBWD32_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW32((unsigned int)(N))
525 #define LIBXSMM_INTRINSICS_BITSCANBWD64_SW(N) LIBXSMM_INTRINSICS_BITSCANBWD_SW64((unsigned long long)(N))
526
/* Bit-scan dispatch: MSVC intrinsics, GNU builtins, or the SW fallback.
 * All variants return the bit index of the least (FWD) or most (BWD)
 * significant set bit, and 0 when the argument is 0. */
#if defined(_WIN32) && !defined(LIBXSMM_INTRINSICS_NONE)
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD32(unsigned int n) {
  unsigned long r = 0; _BitScanForward(&r, n); return (0 != n) * r;
}
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD32(unsigned int n) {
  /* _BitScanReverse leaves *Index undefined when n is 0; mask the result
   * so the zero-input contract matches the FWD/GNU/SW variants. */
  unsigned long r = 0; _BitScanReverse(&r, n); return (0 != n) * r;
}
# if defined(_WIN64)
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANFWD64(unsigned long long n) {
  unsigned long r = 0; _BitScanForward64(&r, n); return (0 != n) * r;
}
LIBXSMM_API_INLINE unsigned int LIBXSMM_INTRINSICS_BITSCANBWD64(unsigned long long n) {
  /* same zero-input guard as the 32-bit variant (see above) */
  unsigned long r = 0; _BitScanReverse64(&r, n); return (0 != n) * r;
}
# else
#   define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW
#   define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW
# endif
#elif defined(__GNUC__) && !defined(LIBXSMM_INTRINSICS_NONE)
/* NOTE: N is evaluated twice -- only pass side-effect-free expressions. */
# define LIBXSMM_INTRINSICS_BITSCANFWD32(N) ((0 != (N)) * __builtin_ctz(N))
# define LIBXSMM_INTRINSICS_BITSCANFWD64(N) ((0 != (N)) * __builtin_ctzll(N))
# define LIBXSMM_INTRINSICS_BITSCANBWD32(N) ((0 != (N)) * (31 - __builtin_clz(N)))
# define LIBXSMM_INTRINSICS_BITSCANBWD64(N) ((0 != (N)) * (63 - __builtin_clzll(N)))
#else /* fall-back implementation */
# define LIBXSMM_INTRINSICS_BITSCANFWD32 LIBXSMM_INTRINSICS_BITSCANFWD32_SW
# define LIBXSMM_INTRINSICS_BITSCANFWD64 LIBXSMM_INTRINSICS_BITSCANFWD64_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD32 LIBXSMM_INTRINSICS_BITSCANBWD32_SW
# define LIBXSMM_INTRINSICS_BITSCANBWD64 LIBXSMM_INTRINSICS_BITSCANBWD64_SW
#endif
556
/** LIBXSMM_NBITS determines the minimum number of bits needed to represent N.
 *  Index of the highest set bit plus one; the LIBXSMM_MIN(1, N) term makes NBITS(0) == 0. */
#define LIBXSMM_NBITS(N) (LIBXSMM_INTRINSICS_BITSCANBWD64(N) + LIBXSMM_MIN(1, N))
/* Coarse integer square root: 2^(NBITS(N)/2); a power-of-two approximation, not exact sqrt. */
#define LIBXSMM_ISQRT2(N) ((unsigned int)((1ULL << (LIBXSMM_NBITS(N) >> 1)) /*+ LIBXSMM_MIN(1, N)*/))
560 /** LIBXSMM_ILOG2 definition matches ceil(log2(N)). */
LIBXSMM_ILOG2(unsigned long long n)561 LIBXSMM_API_INLINE unsigned int LIBXSMM_ILOG2(unsigned long long n) {
562 unsigned int result = 0; if (1 < n) {
563 const unsigned int m = LIBXSMM_INTRINSICS_BITSCANBWD64(n);
564 result = m + ((unsigned int)LIBXSMM_INTRINSICS_BITSCANBWD64(n - 1) == m);
565 } return result;
566 }
567
568 /**
569 * Target attribution
570 */
/* Each LIBXSMM_INTRINSICS_<ISA> symbol below is defined when either the static
 * build target already includes that ISA (LIBXSMM_STATIC_TARGET_ARCH), or --
 * unless LIBXSMM_INTRINSICS_STATIC restricts to the static target -- the
 * compiler can still emit it (LIBXSMM_MAX_STATIC_TARGET_ARCH), e.g. via
 * per-function target attributes. LIBXSMM_INTRINSICS_NONE disables all. */
#if !defined(LIBXSMM_INTRINSICS_KNC) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(__MIC__)
# define LIBXSMM_INTRINSICS_KNC
#endif
/** LIBXSMM_INTRINSICS_X86 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_X86) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_GENERIC <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_GENERIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_X86
#endif
/** LIBXSMM_INTRINSICS_SSE3 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_SSE3) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE3 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE3 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_SSE3
#endif
/** LIBXSMM_INTRINSICS_SSE4 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_SSE4) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_SSE4 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_SSE4 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_SSE4
#endif
/** LIBXSMM_INTRINSICS_AVX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX
#endif
/** LIBXSMM_INTRINSICS_AVX2 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX2) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX2 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX2 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX2
#endif
/** LIBXSMM_INTRINSICS_AVX512 is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512 <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512
#endif
/** LIBXSMM_INTRINSICS_AVX512_MIC is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_MIC) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_MIC <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_MIC <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_MIC
#endif
/** LIBXSMM_INTRINSICS_AVX512_KNM is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_KNM) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_KNM <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_KNM <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_KNM
#endif
/** LIBXSMM_INTRINSICS_AVX512_CORE is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CORE) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CORE <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CORE <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_CORE
#endif
/** LIBXSMM_INTRINSICS_AVX512_CLX is defined only if the compiler is able to generate this code without special flags. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CLX) && !defined(LIBXSMM_INTRINSICS_NONE) && (LIBXSMM_X86_AVX512_CLX <= LIBXSMM_STATIC_TARGET_ARCH || \
   (!defined(LIBXSMM_INTRINSICS_STATIC) && LIBXSMM_X86_AVX512_CLX <= LIBXSMM_MAX_STATIC_TARGET_ARCH))
# define LIBXSMM_INTRINSICS_AVX512_CLX
#endif
/** LIBXSMM_INTRINSICS_AVX512_CPX is defined only if the compiler is able to generate this code without special flags.
 *  NOTE(review): unlike the cases above, CPX never keys off LIBXSMM_STATIC_TARGET_ARCH -- presumably intentional; confirm. */
#if !defined(LIBXSMM_INTRINSICS_AVX512_CPX) && !defined(LIBXSMM_INTRINSICS_NONE) && defined(LIBXSMM_X86_AVX512_CPX) && \
   !defined(LIBXSMM_INTRINSICS_STATIC) && (LIBXSMM_X86_AVX512_CPX <= LIBXSMM_MAX_STATIC_TARGET_ARCH)
# define LIBXSMM_INTRINSICS_AVX512_CPX
#endif
629
630 /**
631 * Pseudo intrinsics (AVX-512)
632 */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
/* Scales the 16 FP32 lanes loaded from A by B, rounds to nearest (no exceptions),
 * and narrows the resulting 32-bit integers to 16-bit integers. */
# define LIBXSMM_INTRINSICS_MM512_QUANTIZE_NEAR_PS_EPI16( A, B ) _mm512_cvtepi32_epi16(_mm512_cvt_roundps_epi32( \
    _mm512_mul_ps(LIBXSMM_INTRINSICS_MM512_LOAD_PS(A), B), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
636
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)637 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(__m512 a) {
638 const __m512i vnaninf = _mm512_set1_epi32(0x7f800000), vrneadd = _mm512_set1_epi32(0x00007fff);
639 const __m512i vfixup = _mm512_set1_epi32(0x00000001), vfixupmask = _mm512_set1_epi32(0x00010000);
640 const __m512i mm512_roundbf16rne_a_ = _mm512_castps_si512(a);
641 const __mmask16 mm512_roundbf16rne_mask1_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vnaninf), vnaninf, _MM_CMPINT_NE);
642 const __mmask16 mm512_roundbf16rne_mask2_ = _mm512_cmp_epi32_mask(_mm512_and_epi32(mm512_roundbf16rne_a_, vfixupmask), vfixupmask, _MM_CMPINT_EQ);
643 return _mm512_mask_add_epi32(mm512_roundbf16rne_a_, mm512_roundbf16rne_mask1_, mm512_roundbf16rne_a_, _mm512_mask_add_epi32(vrneadd, mm512_roundbf16rne_mask2_, vrneadd, vfixup));
644 }
645
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)646 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m256i LIBXSMM_INTRINSICS_MM512_CVT_FP32_BF16(__m512 a) {
647 return _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16));
648 }
649
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)650 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_CVT2_FP32_BF16(__m512 a, __m512 b) {
651 const __m256i aa = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(b), 16));
652 const __m256i bb = _mm512_cvtepi32_epi16(_mm512_srai_epi32(LIBXSMM_INTRINSICS_MM512_ROUNDNE_BF16(a), 16));
653 return _mm512_inserti64x4(_mm512_inserti64x4(_mm512_setzero_si512(), aa, 0), bb, 1);
654 }
655
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)656 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_CVTPBH_PS(__m256i a) {
657 return _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepi16_epi32(a),16));
658 }
659
660 /** SVML-intrinsics */
/** Vectorized tanh(x) via a degree-7/8 rational approximation:
 *  nom = x*(c3*x^6 + c2*x^4 + c1*x^2 + c0), denom = x^8 + c3_d*x^6 + c2_d*x^4 + c1_d*x^2 + c0,
 *  saturated to +/-1 outside (-4.97, 4.97). Uses rcp14 (~14-bit) instead of a true divide. */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_78( __m512 x ) {
  const __m512 c0 = _mm512_set1_ps(2027025.0f);
  const __m512 c1 = _mm512_set1_ps(270270.0f);
  const __m512 c2 = _mm512_set1_ps(6930.0f);
  const __m512 c3 = _mm512_set1_ps(36.0f);
  const __m512 c1_d = _mm512_set1_ps(945945.0f);
  const __m512 c2_d = _mm512_set1_ps(51975.0f);
  const __m512 c3_d = _mm512_set1_ps(630.0f);
  const __m512 hi_bound = _mm512_set1_ps(4.97f);
  const __m512 lo_bound = _mm512_set1_ps(-4.97f);
  const __m512 ones = _mm512_set1_ps(1.0f);
  const __m512 neg_ones = _mm512_set1_ps(-1.0f);

  const __m512 x2 = _mm512_mul_ps( x, x );
  /* Horner evaluation of numerator polynomial in x^2, then multiplied by x */
  const __m512 t1_nom = _mm512_fmadd_ps( c3, x2, c2 );
  const __m512 t2_nom = _mm512_fmadd_ps( t1_nom, x2, c1 );
  const __m512 t3_nom = _mm512_fmadd_ps( t2_nom, x2, c0 );
  const __m512 nom = _mm512_mul_ps( t3_nom, x );
  /* Horner evaluation of denominator polynomial in x^2 (leading coefficient 1) */
  const __m512 t1_denom = _mm512_add_ps( x2, c3_d );
  const __m512 t2_denom = _mm512_fmadd_ps( t1_denom, x2, c2_d );
  const __m512 t3_denom = _mm512_fmadd_ps( t2_denom, x2, c1_d );
  const __m512 denom = _mm512_fmadd_ps( t3_denom, x2, c0 );
  const __m512 denom_rcp = _mm512_rcp14_ps( denom ); /* approximate reciprocal, no Newton step */
  const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ);
  const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ);
  __m512 result = _mm512_mul_ps( nom, denom_rcp );
  /* clamp to +/-1 where the rational form is no longer accurate */
  result = _mm512_mask_blend_ps(mask_hi, result, ones);
  result = _mm512_mask_blend_ps(mask_lo, result, neg_ones);

  return result;
}
692
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)693 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_RATIONAL_32( __m512 x ) {
694 const __m512 c1 = _mm512_set1_ps((float)(1.0/27.0));
695 const __m512 c2 = _mm512_set1_ps((float)(1.0/3));
696 const __m512 hi_bound = _mm512_set1_ps(3.2f);
697 const __m512 lo_bound = _mm512_set1_ps(-3.2f);
698 const __m512 ones = _mm512_set1_ps(1.0f);
699 const __m512 neg_ones = _mm512_set1_ps(-1.0f);
700
701 const __m512 x2 = _mm512_mul_ps( x, x );
702 const __m512 t1_nom = _mm512_fmadd_ps( x2, c1, ones);
703 const __m512 nom = _mm512_mul_ps( t1_nom, x );
704 const __m512 denom = _mm512_fmadd_ps( x2, c2, ones);
705 const __m512 denom_rcp = _mm512_rcp14_ps( denom );
706 const __mmask16 mask_hi = _mm512_cmp_ps_mask( x, hi_bound, _CMP_GT_OQ);
707 const __mmask16 mask_lo = _mm512_cmp_ps_mask( x, lo_bound, _CMP_LT_OQ);
708 __m512 result = _mm512_mul_ps(nom, denom_rcp);
709 result = _mm512_mask_blend_ps(mask_hi, result, ones);
710 result = _mm512_mask_blend_ps(mask_lo, result, neg_ones);
711
712 return result;
713 }
714
/** Vectorized tanh via the identity tanh(x) = 1 - 2/(exp(2x) + 1).
 *  exp(2x) is computed as 2^(2x*log2(e)) with scalef and a degree-2
 *  polynomial approximating 2^y on the fractional part y. */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP2( __m512 _x ) {
  const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2)); /* 2*log2(e) */
  const __m512 half = _mm512_set1_ps(0.5f);
  const __m512 c2 = _mm512_set1_ps(0.240226507f);
  const __m512 c1 = _mm512_set1_ps(0.452920674f);
  const __m512 c0 = _mm512_set1_ps(0.713483036f);
  const __m512 ones = _mm512_set1_ps(1.0f);
  const __m512 minus_twos = _mm512_set1_ps(-2.0f);

  const __m512 x = _mm512_fmadd_ps(_x, twice_log2_e, half);
#if 1
  /* fractional part of x (imm 1: SAE/signed-zero behavior per roundscale encoding) */
  const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
#else
  const __m512 y = _mm512_reduce_ps(x, 1);
#endif
  /* Horner: 2^y ~ c2*y^2 + c1*y + c0 on the reduced range */
  const __m512 t1 = _mm512_fmadd_ps( y, c2, c1);
  const __m512 two_to_y = _mm512_fmadd_ps( y, t1, c0);
  const __m512 exp = _mm512_scalef_ps( two_to_y, x ); /* 2^floor(x) * 2^y */
  const __m512 denom_rcp = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) );
  __m512 result = _mm512_fmadd_ps( denom_rcp, minus_twos, ones);

  return result;
}
738
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)739 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_EXP3( __m512 _x ) {
740 const __m512 twice_log2_e = _mm512_set1_ps((float)(1.442695*2));
741 const __m512 half = _mm512_set1_ps(0.5f);
742 const __m512 c3 = _mm512_set1_ps(0.05550410866f);
743 const __m512 c2 = _mm512_set1_ps(0.15697034396f);
744 const __m512 c1 = _mm512_set1_ps(0.49454875509f);
745 const __m512 c0 = _mm512_set1_ps(0.70654502287f);
746 const __m512 ones = _mm512_set1_ps(1.0f);
747 const __m512 minus_twos = _mm512_set1_ps(-2.0f);
748
749 const __m512 x = _mm512_fmadd_ps(_x, twice_log2_e, half);
750 #if 1
751 const __m512 y = _mm512_sub_ps(x, _mm512_roundscale_round_ps(x, 1, _MM_FROUND_CUR_DIRECTION));
752 #else
753 const __m512 y = _mm512_reduce_ps(x, 1);
754 #endif
755 const __m512 t1 = _mm512_fmadd_ps( y, c3, c2);
756 const __m512 t2 = _mm512_fmadd_ps( y, t1, c1);
757 const __m512 two_to_y = _mm512_fmadd_ps( y, t2, c0);
758 const __m512 exp = _mm512_scalef_ps( two_to_y, x );
759 const __m512 denom_rcp = _mm512_rcp14_ps( _mm512_add_ps( exp, ones) );
760 __m512 result = _mm512_fmadd_ps( denom_rcp, minus_twos, ones);
761
762 return result;
763 }
764
/** Vectorized tanh via a piecewise degree-2 minimax polynomial:
 *  works on |x|, selects per-lane coefficients from 16-entry LUTs indexed by
 *  the top exponent bits (abs bits >> 22, clamped to [246, 261]), evaluates
 *  p0 + p1*|x| + p2*|x|^2, and restores the sign of x at the end. */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX2( __m512 x ) {
  __m512 result, func_p0, func_p1, func_p2;
  const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 );
  const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF );
  const __m512i lut_low = _mm512_set1_epi32( 246 );
  const __m512i lut_high = _mm512_set1_epi32( 261 );
  /* per-segment polynomial coefficients (one entry per exponent segment) */
  const __m512 tanh_p0_2_reg = _mm512_set_ps( 0.40555000f, 0.11892800f, -0.00972979f, -0.02740300f, -0.0169851f, -0.00776152f, -0.00305889f,
                                             -0.00116259f, -0.00041726f, -8.53233e-6f, 1.0000000f, 0.99999800f, 0.99975400f, 0.99268200f,
                                              0.93645300f, 0.73833900f);
  const __m512 tanh_p1_2_reg = _mm512_set_ps( 0.495602f, 0.88152f, 1.125700000f, 1.17021000f, 1.1289000000f, 1.07929000f, 1.0432300f, 1.023010f,
                                              1.011620f, 1.00164f, 1.56828e-14f, 4.49924e-7f, 0.0000646924f, 0.00260405f, 0.0311608f, 0.168736f);
  const __m512 tanh_p2_2_reg = _mm512_set_ps(-0.108238f, -0.2384280f, -0.354418000f, -0.38240300f, -0.34135700f, -0.274509000f, -0.20524900f, -0.1511960f,
                                             -0.107635f, -0.0466868f, -3.60822e-16f, -2.05971e-8f, -4.24538e-6f, -0.000231709f, -0.00386434f, -0.0277702f);

  const __m512i signs = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask);
  const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter);
  /* exponent-derived segment index, clamped to the LUT's covered range */
  __m512i indices = _mm512_srli_epi32(abs_arg, 22);
  indices = _mm512_max_epi32(indices, lut_low);
  indices = _mm512_min_epi32(indices, lut_high);

  /* gather the three coefficients of this lane's segment */
  func_p0 = _mm512_permutexvar_ps(indices, tanh_p0_2_reg);
  func_p1 = _mm512_permutexvar_ps(indices, tanh_p1_2_reg);
  func_p2 = _mm512_permutexvar_ps(indices, tanh_p2_2_reg);

  /* Horner evaluation on |x|, then reapply the original sign (tanh is odd) */
  result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p2, func_p1);
  result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0);
  result = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs));

  return result;
}
795
/** Vectorized tanh via a piecewise degree-3 minimax polynomial; same segment
 *  scheme as the MINIMAX2 variant (exponent-indexed 16-entry LUTs, indices
 *  clamped to [246, 261]) with one extra coefficient per segment. */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS_MINIMAX3( __m512 x ) {
  __m512 result, func_p0, func_p1, func_p2, func_p3;
  const __m512i sign_mask = _mm512_set1_epi32( 0x80000000 );
  const __m512i sign_filter = _mm512_set1_epi32( 0x7FFFFFFF );
  const __m512i lut_low = _mm512_set1_epi32( 246 );
  const __m512i lut_high = _mm512_set1_epi32( 261 );

  /* per-segment polynomial coefficients (setr: entry 0 is lane 0) */
  const __m512 tanh_p0_3_reg = _mm512_setr_ps( 0.466283000f, 0.82850600f, 0.97437500f, 0.99882600f, 0.9999860f, 1.0000000f, -1.50006e-08f, -7.98169e-06f,
                                              -4.53753e-05f, -0.00023755f, -0.00125285f, -0.00572314f, -0.0227717f, -0.0629089f, -0.084234300f, 0.071199800f);
  const __m512 tanh_p1_3_reg = _mm512_setr_ps( 0.500617f, 0.124369f, 0.0137214f, 0.000464124f, 4.02465e-06f, 0.00000f, 1.00001f, 1.00028f, 1.00112f, 1.00414f,
                                               1.015570f, 1.050950f, 1.1478500f, 1.310130000f, 1.378950000f, 1.07407f);
  const __m512 tanh_p2_3_reg = _mm512_setr_ps(-0.16133200f, -0.0305526f, -0.00245909f, -6.12647e-05f, -3.76127e-07f, 0.000000f, -0.000245872f, -0.00341151f,
                                              -0.00971505f, -0.0256817f, -0.06869110f, -0.162433000f, -0.346828000f, -0.566516f, -0.640214000f, -0.44011900f);
  const __m512 tanh_p3_3_reg = _mm512_setr_ps( 0.0177393f, 0.00253432f, 0.000147303f, 2.69963e-06f, 1.16764e-08f, 0.0000000f, -0.330125f, -0.3176210f,
                                              -0.3017760f, -0.27358000f, -0.219375000f, -0.136197000f, -0.01868680f, 0.0808901f, 0.107095f, 0.0631459f);

  const __m512i signs = _mm512_and_epi32(_mm512_castps_si512(x), sign_mask);
  const __m512i abs_arg = _mm512_and_epi32(_mm512_castps_si512(x), sign_filter);
  /* exponent-derived segment index, clamped to the LUT's covered range */
  __m512i indices = _mm512_srli_epi32(abs_arg, 22);
  indices = _mm512_max_epi32(indices, lut_low);
  indices = _mm512_min_epi32(indices, lut_high);

  func_p0 = _mm512_permutexvar_ps(indices, tanh_p0_3_reg);
  func_p1 = _mm512_permutexvar_ps(indices, tanh_p1_3_reg);
  func_p2 = _mm512_permutexvar_ps(indices, tanh_p2_3_reg);
  func_p3 = _mm512_permutexvar_ps(indices, tanh_p3_3_reg);

  /* Horner evaluation on |x|, then reapply the original sign (tanh is odd) */
  result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), func_p3, func_p2);
  result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p1);
  result = _mm512_fmadd_ps(_mm512_castsi512_ps(abs_arg), result, func_p0);
  result = _mm512_castsi512_ps(_mm512_xor_epi32(_mm512_castps_si512(result), signs));

  return result;
}
830
831 #if defined(LIBXSMM_INTEL_COMPILER)
832 # define LIBXSMM_INTRINSICS_MM512_TANH_PS(A) _mm512_tanh_ps(A)
833 # define LIBXSMM_INTRINSICS_MM512_EXP_PS(A) _mm512_exp_ps(A)
834 #else
835 # if !defined(LIBXSMM_NO_LIBM)
836 # include <math.h>
837 # endif
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)838 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_TANH_PS(__m512 a) {
839 float a16[16]; int i;
840 _mm512_storeu_ps(a16, a);
841 for (i = 0; i < 16; ++i) a16[i] = LIBXSMM_TANHF(a16[i]);
842 return _mm512_loadu_ps(a16);
843 }
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)844 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_EXP_PS(__m512 a) {
845 float a16[16]; int i;
846 _mm512_storeu_ps(a16, a);
847 for (i = 0; i < 16; ++i) a16[i] = LIBXSMM_EXPF(a16[i]);
848 return _mm512_loadu_ps(a16);
849 }
850 #endif /* SVML */
851
/** 2048-bit state for xoshiro128+ RNG: four 512-bit words (s0..s3), one RNG stream per lane. */
/* NOTE(review): aliases unsigned int[16] as __m512i; assumes LIBXSMM_APIVAR_PUBLIC
 * provides adequate alignment for the deref -- confirm against its definition. */
#define LIBXSMM_INTRINSICS_MM512_RNG_STATE(INDEX) (*(__m512i*)LIBXSMM_CONCATENATE(libxsmm_intrinsics_mm512_rng_state, INDEX))
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state0[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state1[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state2[16]);
LIBXSMM_APIVAR_PUBLIC(unsigned int libxsmm_intrinsics_mm512_rng_state3[16]);
858
# if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0 /* currently disabled (&& 0) */
LIBXSMM_PRAGMA_OPTIMIZE_OFF /* avoid ICE in case of symbols (-g) */
# endif
/** Generate random number in the interval [0, 1); not thread-safe.
 *  this is based on xoshiro128+ 1.0, e.g. http://prng.di.unimi.it/xoshiro128plus.c
 *  Runs 16 independent 32-bit xoshiro128+ streams, one per SIMD lane, on the
 *  global state; the statement order of the state transition is significant. */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(void) {
  /* xoshiro128+ output is s0 + s3, taken before the state transition */
  const __m512i result = _mm512_add_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
  const __m512i s = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), 9);
  __m512i t;
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), LIBXSMM_INTRINSICS_MM512_RNG_STATE(0));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), LIBXSMM_INTRINSICS_MM512_RNG_STATE(1));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(1), LIBXSMM_INTRINSICS_MM512_RNG_STATE(2));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(0), LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_xor_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(2), s);
  /* s3 = rotl(s3, 11), composed from shift-left and shift-right */
  t = _mm512_slli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 11);
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_or_epi32(t, _mm512_srli_epi32(LIBXSMM_INTRINSICS_MM512_RNG_STATE(3), 32 - 11));
  return result;
}
877
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)878 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_RNG_PS(void) {
879 const __m512i rng_mantissa = _mm512_srli_epi32( LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EPI32(), 9 );
880 const __m512 one = _mm512_set1_ps(1.0f);
881 return _mm512_sub_ps(_mm512_castsi512_ps(_mm512_or_epi32(_mm512_set1_epi32(0x3f800000), rng_mantissa)), one);
882 }
883
/** Generate random number in the interval [0, 1); thread-safe, state needs to be managed by user.
 *  this is based on xoshiro128+ 1.0, e.g. http://prng.di.unimi.it/xoshiro128plus.c
 *  stateptr must reference 64 unsigned ints: four 512-bit words s0..s3 (16 lanes each). */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512i LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32( unsigned int* stateptr ) {
  __m512i state_0 = _mm512_loadu_si512( stateptr );
  __m512i state_1 = _mm512_loadu_si512( stateptr+16 );
  __m512i state_2 = _mm512_loadu_si512( stateptr+32 );
  __m512i state_3 = _mm512_loadu_si512( stateptr+48 );
  /* xoshiro128+ output is s0 + s3, taken before the state transition */
  const __m512i result = _mm512_add_epi32(state_0, state_3);
  const __m512i s = _mm512_slli_epi32(state_1, 9);
  __m512i t;
  /* state transition; statement order is significant */
  state_2 = _mm512_xor_epi32(state_2, state_0);
  state_3 = _mm512_xor_epi32(state_3, state_1);
  state_1 = _mm512_xor_epi32(state_1, state_2);
  state_0 = _mm512_xor_epi32(state_0, state_3);
  state_2 = _mm512_xor_epi32(state_2, s);
  _mm512_storeu_si512( stateptr   , state_0 );
  _mm512_storeu_si512( stateptr+16, state_1 );
  _mm512_storeu_si512( stateptr+32, state_2 );
  /* s3 = rotl(s3, 11) */
  t = _mm512_slli_epi32(state_3, 11);
  state_3 = _mm512_or_epi32(t, _mm512_srli_epi32(state_3, 32 - 11));
  _mm512_storeu_si512( stateptr+48, state_3 );
  return result;
}
907
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)908 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512) __m512 LIBXSMM_INTRINSICS_MM512_RNG_EXTSTATE_PS( unsigned int* stateptr) {
909 const __m512i rng_mantissa = _mm512_srli_epi32( LIBXSMM_INTRINSICS_MM512_RNG_XOSHIRO128P_EXTSTATE_EPI32( stateptr ), 9 );
910 const __m512 one = _mm512_set1_ps(1.0f);
911 return _mm512_sub_ps(_mm512_castsi512_ps(_mm512_or_epi32(_mm512_set1_epi32(0x3f800000), rng_mantissa)), one);
912 }
913 # if defined(__GNUC__) && !defined(__clang__) && !defined(LIBXSMM_INTEL_COMPILER) && !defined(_CRAYC) && 0
914 LIBXSMM_PRAGMA_OPTIMIZE_ON
915 # endif
916 #endif /*__AVX512F__*/
917
918 #if defined(LIBXSMM_OFFLOAD_TARGET)
919 # pragma offload_attribute(pop)
920 #endif
921
922 #endif /*LIBXSMM_INTRINSICS_X86_H*/
923
924