1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke, Hans Pabst (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_rng.h"
12 #include "libxsmm_main.h"
13
14 #if !defined(LIBXSMM_RNG_DRAND48) && (!defined(_WIN32) && !defined(__CYGWIN__) && (defined(_SVID_SOURCE) || defined(_XOPEN_SOURCE)))
15 # define LIBXSMM_RNG_DRAND48
16 #endif
17
18 #if !defined(LIBXSMM_RNG_SIMD_MIN)
19 # define LIBXSMM_RNG_SIMD_MIN 8
20 #endif
21
22 /* dispatched RNG functions (separate typedef for legacy Cray C++ needed) */
23 typedef void (*internal_rng_f32_seq_fn)(float*, libxsmm_blasint);
24 LIBXSMM_APIVAR_DEFINE(internal_rng_f32_seq_fn internal_rng_f32_seq);
25 /* 2048-bit state for RNG */
26 LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state0[16]);
27 LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state1[16]);
28 LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state2[16]);
29 LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state3[16]);
30
31
internal_rng_float_jump(uint32_t * state0,uint32_t * state1,uint32_t * state2,uint32_t * state3)32 LIBXSMM_API_INLINE void internal_rng_float_jump(uint32_t* state0, uint32_t* state1, uint32_t* state2, uint32_t* state3)
33 {
34 static const uint32_t jump_table[] = { 0x8764000b, 0xf542d2d3, 0x6fa035c3, 0x77f2db5b };
35 uint32_t s0 = 0, s1 = 0, s2 = 0, s3 = 0;
36 int i, b;
37
38 LIBXSMM_ASSERT(4 == sizeof(jump_table) / sizeof(*jump_table));
39 for (i = 0; i < 4; ++i) {
40 for (b = 0; b < 32; ++b) {
41 if (jump_table[i] & (1U << b)) {
42 s0 ^= *state0;
43 s1 ^= *state1;
44 s2 ^= *state2;
45 s3 ^= *state3;
46 }
47 { /* draw one more integer */
48 const uint32_t t = *state1 << 9;
49 *state2 ^= *state0;
50 *state3 ^= *state1;
51 *state1 ^= *state2;
52 *state0 ^= *state3;
53 *state2 ^= t;
54 *state3 = ((*state3 << 11) | (*state3 >> (32 - 11)));
55 }
56 }
57 }
58 *state0 = s0;
59 *state1 = s1;
60 *state2 = s2;
61 *state3 = s3;
62 }
63
64
internal_rng_scalar_float_next(int i)65 LIBXSMM_API_INLINE float internal_rng_scalar_float_next(int i)
66 {
67 const uint32_t rng_mantissa = (internal_rng_state0[i] + internal_rng_state3[i]) >> 9;
68 const uint32_t t = internal_rng_state1[i] << 9;
69 union { uint32_t i; float f; } rng;
70
71 internal_rng_state2[i] ^= internal_rng_state0[i];
72 internal_rng_state3[i] ^= internal_rng_state1[i];
73 internal_rng_state1[i] ^= internal_rng_state2[i];
74 internal_rng_state0[i] ^= internal_rng_state3[i];
75 internal_rng_state2[i] ^= t;
76 internal_rng_state3[i] = ((internal_rng_state3[i] << 11) | (internal_rng_state3[i] >> (32 - 11)));
77
78 rng.i = 0x3f800000 | rng_mantissa;
79 return rng.f - 1.0f;
80 }
81
82
83 LIBXSMM_API_INTERN void internal_rng_set_seed_sw(uint32_t seed);
internal_rng_set_seed_sw(uint32_t seed)84 LIBXSMM_API_INTERN void internal_rng_set_seed_sw(uint32_t seed)
85 {
86 static const uint32_t temp_state[] = {
87 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
88 131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116,
89 231, 230, 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216,
90 331, 330, 329, 328, 327, 326, 325, 324, 323, 322, 321, 320, 319, 318, 317, 316
91 };
92 libxsmm_blasint i;
93
94 /* finish initializing the state */
95 LIBXSMM_ASSERT((16 * 4) == sizeof(temp_state) / sizeof(*temp_state));
96 for (i = 0; i < 16; ++i) {
97 internal_rng_state0[i] = seed + temp_state[i];
98 internal_rng_state1[i] = seed + temp_state[i+16];
99 internal_rng_state2[i] = seed + temp_state[i+32];
100 internal_rng_state3[i] = seed + temp_state[i+48];
101 }
102 for (i = 0; i < 16; ++i) {
103 internal_rng_float_jump( /* progress each sequence by 2^64 */
104 internal_rng_state0 + i, internal_rng_state1 + i,
105 internal_rng_state2 + i, internal_rng_state3 + i);
106 }
107 /* for consistency, other RNGs are seeded as well */
108 #if !defined(_WIN32) && !defined(__CYGWIN__) && (defined(_SVID_SOURCE) || defined(_XOPEN_SOURCE))
109 srand48(seed);
110 #endif
111 srand(seed);
112 }
113
114
internal_rng_f32_seq_sw(float * rngs,libxsmm_blasint count)115 LIBXSMM_API_INLINE void internal_rng_f32_seq_sw(float* rngs, libxsmm_blasint count)
116 {
117 libxsmm_blasint i = 0;
118 for (; i < count; ++i) {
119 rngs[i] = internal_rng_scalar_float_next(LIBXSMM_MOD2(i, 16));
120 }
121 }
122
123
124 #if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)125 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
126 void internal_rng_set_seed_avx512(uint32_t seed)
127 {
128 internal_rng_set_seed_sw(seed);
129 /* bring scalar state to AVX-512 */
130 LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_loadu_si512(internal_rng_state0);
131 LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_loadu_si512(internal_rng_state1);
132 LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_loadu_si512(internal_rng_state2);
133 LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_loadu_si512(internal_rng_state3);
134 }
135
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)136 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
137 void internal_rng_f32_seq_avx512(float* rngs, libxsmm_blasint count)
138 {
139 if ((LIBXSMM_RNG_SIMD_MIN << 4) <= count) { /* SIMD code path */
140 const libxsmm_blasint n = (count >> 4) << 4; /* multiple of vector-length */
141 libxsmm_blasint i = 0;
142 for (; i < n; i += 16) {
143 _mm512_storeu_ps(rngs + i, LIBXSMM_INTRINSICS_MM512_RNG_PS());
144 }
145 if (i < count) { /* remainder */
146 #if 0 /* assert(0 < n) */
147 if (0 < n)
148 #endif
149 { /* bring AVX-512 state to scalar */
150 _mm512_storeu_si512(internal_rng_state0, LIBXSMM_INTRINSICS_MM512_RNG_STATE(0));
151 _mm512_storeu_si512(internal_rng_state1, LIBXSMM_INTRINSICS_MM512_RNG_STATE(1));
152 _mm512_storeu_si512(internal_rng_state2, LIBXSMM_INTRINSICS_MM512_RNG_STATE(2));
153 _mm512_storeu_si512(internal_rng_state3, LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
154 }
155 LIBXSMM_ASSERT(count < i + 16);
156 do { /* scalar remainder */
157 rngs[i] = internal_rng_scalar_float_next(LIBXSMM_MOD2(i, 16));
158 ++i;
159 } while (i < count);
160 /* bring scalar state to AVX-512 */
161 LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_loadu_si512(internal_rng_state0);
162 LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_loadu_si512(internal_rng_state1);
163 LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_loadu_si512(internal_rng_state2);
164 LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_loadu_si512(internal_rng_state3);
165 }
166 }
167 else { /* scalar code path */
168 internal_rng_f32_seq_sw(rngs, count);
169 }
170 }
171 #endif /*defined(LIBXSMM_INTRINSICS_AVX512)*/
172
173
libxsmm_rng_create_avx512_extstate(unsigned int seed)174 LIBXSMM_API unsigned int* libxsmm_rng_create_avx512_extstate(unsigned int/*uint32_t*/ seed)
175 {
176 unsigned int* state = (unsigned int*) libxsmm_aligned_malloc( 64*sizeof(unsigned int), 64 );
177 static const uint32_t temp_state[] = {
178 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
179 131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116,
180 231, 230, 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216,
181 331, 330, 329, 328, 327, 326, 325, 324, 323, 322, 321, 320, 319, 318, 317, 316
182 };
183 libxsmm_blasint i;
184
185 /* finish initializing the state */
186 LIBXSMM_ASSERT((16 * 4) == sizeof(temp_state) / sizeof(*temp_state));
187 for (i = 0; i < 16; ++i) {
188 state[i ] = seed + temp_state[i];
189 state[i+16] = seed + temp_state[i+16];
190 state[i+32] = seed + temp_state[i+32];
191 state[i+48] = seed + temp_state[i+48];
192 }
193 for (i = 0; i < 16; ++i) {
194 internal_rng_float_jump( /* progress each sequence by 2^64 */
195 state + i, state + 16 + i,
196 state + 32 + i, state + 48 + i);
197 }
198
199 return state;
200 }
201
202
libxsmm_rng_destroy_avx512_extstate(unsigned int * stateptr)203 LIBXSMM_API void libxsmm_rng_destroy_avx512_extstate(unsigned int* stateptr)
204 {
205 if ( stateptr != NULL ) {
206 libxsmm_free( stateptr );
207 }
208 }
209
210
libxsmm_rng_set_seed(unsigned int seed)211 LIBXSMM_API void libxsmm_rng_set_seed(unsigned int/*uint32_t*/ seed)
212 {
213 LIBXSMM_INIT
214 #if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
215 # if !defined(NDEBUG) /* used to track if seed is initialized */
216 internal_rng_f32_seq = internal_rng_f32_seq_avx512;
217 # endif
218 internal_rng_set_seed_avx512(seed);
219 #elif defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
220 if (LIBXSMM_X86_AVX512 <= libxsmm_target_archid) {
221 internal_rng_f32_seq = internal_rng_f32_seq_avx512;
222 internal_rng_set_seed_avx512(seed);
223 }
224 else {
225 internal_rng_f32_seq = internal_rng_f32_seq_sw;
226 internal_rng_set_seed_sw(seed);
227 }
228 #else
229 # if !defined(NDEBUG) /* used to track if seed is initialized */
230 internal_rng_f32_seq = internal_rng_f32_seq_sw;
231 # endif
232 internal_rng_set_seed_sw(seed);
233 #endif
234 }
235
236
libxsmm_rng_f32_seq(float * rngs,libxsmm_blasint count)237 LIBXSMM_API void libxsmm_rng_f32_seq(float* rngs, libxsmm_blasint count)
238 {
239 LIBXSMM_ASSERT_MSG(NULL != internal_rng_f32_seq, "RNG must be initialized");
240 #if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
241 internal_rng_f32_seq_avx512(rngs, count);
242 #else
243 # if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
244 if ((LIBXSMM_RNG_SIMD_MIN << 4) <= count) { /* SIMD code path */
245 internal_rng_f32_seq(rngs, count); /* pointer based function call */
246 }
247 else /* scalar code path */
248 # endif
249 internal_rng_f32_seq_sw(rngs, count);
250 #endif
251 }
252
253
libxsmm_rng_u32(unsigned int n)254 LIBXSMM_API unsigned int libxsmm_rng_u32(unsigned int n)
255 {
256 #if defined(LIBXSMM_RNG_DRAND48)
257 const unsigned int q = ((1U << 31) / n) * n;
258 unsigned int r = (unsigned int)lrand48();
259 if (q != (1U << 31))
260 #else
261 const unsigned int rand_max1 = (unsigned int)(RAND_MAX)+1U;
262 const unsigned int q = (rand_max1 / n) * n;
263 unsigned int r = (unsigned int)rand();
264 if (q != rand_max1)
265 #endif
266 {
267 #if defined(LIBXSMM_RNG_DRAND48)
268 /* coverity[dont_call] */
269 while (q <= r) r = (unsigned int)lrand48();
270 #else
271 while (q <= r) r = (unsigned int)rand();
272 #endif
273 }
274 return r % n;
275 }
276
277
libxsmm_rng_seq(void * data,libxsmm_blasint nbytes)278 LIBXSMM_API void libxsmm_rng_seq(void* data, libxsmm_blasint nbytes)
279 {
280 unsigned char* dst = (unsigned char*)data;
281 unsigned char* end = dst + (nbytes & 0xFFFFFFFFFFFFFFFC);
282 unsigned int r;
283 for (; dst < end; dst += 4) {
284 #if defined(LIBXSMM_RNG_DRAND48)
285 /* coverity[dont_call] */
286 r = (unsigned int)lrand48();
287 #else
288 r = (unsigned int)rand();
289 #endif
290 LIBXSMM_MEMCPY127(dst, &r, 4);
291 }
292 end = (unsigned char*)data + nbytes;
293 if (dst < end) {
294 #if defined(LIBXSMM_RNG_DRAND48)
295 r = (unsigned int)lrand48();
296 #else
297 r = (unsigned int)rand();
298 #endif
299 LIBXSMM_MEMCPY127(dst, &r, end - dst);
300 }
301 }
302
303
libxsmm_rng_f64(void)304 LIBXSMM_API double libxsmm_rng_f64(void)
305 {
306 #if defined(LIBXSMM_RNG_DRAND48)
307 /* coverity[dont_call] */
308 return drand48();
309 #else
310 static const double scale = 1.0 / (RAND_MAX);
311 return scale * (double)rand();
312 #endif
313 }
314
315