/******************************************************************************
* Copyright (c) Intel Corporation - All rights reserved.                      *
* This file is part of the LIBXSMM library.                                   *
*                                                                             *
* For information on the license, see the LICENSE file.                       *
* Further information: https://github.com/hfp/libxsmm/                        *
* SPDX-License-Identifier: BSD-3-Clause                                       *
******************************************************************************/
/* Alexander Heinecke, Hans Pabst (Intel Corp.)
******************************************************************************/
#include "libxsmm_rng.h"
#include "libxsmm_main.h"

#if !defined(LIBXSMM_RNG_DRAND48) && (!defined(_WIN32) && !defined(__CYGWIN__) && (defined(_SVID_SOURCE) || defined(_XOPEN_SOURCE)))
# define LIBXSMM_RNG_DRAND48
#endif

#if !defined(LIBXSMM_RNG_SIMD_MIN)
# define LIBXSMM_RNG_SIMD_MIN 8
#endif

/* dispatched RNG functions (separate typedef for legacy Cray C++ needed) */
typedef void (*internal_rng_f32_seq_fn)(float*, libxsmm_blasint);
LIBXSMM_APIVAR_DEFINE(internal_rng_f32_seq_fn internal_rng_f32_seq);
/* 2048-bit state for RNG */
LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state0[16]);
LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state1[16]);
LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state2[16]);
LIBXSMM_APIVAR_DEFINE(unsigned int internal_rng_state3[16]);


LIBXSMM_API_INLINE void internal_rng_float_jump(uint32_t* state0, uint32_t* state1, uint32_t* state2, uint32_t* state3)
{
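  /* Jump routine for one lane of the 16-lane state above (16 x 128 bit = 2048 bit
   * in total, one lane per SIMD element). The constants below appear to match the
   * published jump polynomial of the xoshiro128 family, i.e. a single call is
   * equivalent to advancing the given lane by 2^64 draws. */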
  static const uint32_t jump_table[] = { 0x8764000b, 0xf542d2d3, 0x6fa035c3, 0x77f2db5b };
  uint32_t s0 = 0, s1 = 0, s2 = 0, s3 = 0;
  int i, b;

  LIBXSMM_ASSERT(4 == sizeof(jump_table) / sizeof(*jump_table));
  for (i = 0; i < 4; ++i) {
    for (b = 0; b < 32; ++b) {
      if (jump_table[i] & (1U << b)) {
        s0 ^= *state0;
        s1 ^= *state1;
        s2 ^= *state2;
        s3 ^= *state3;
      }
      { /* draw one more integer */
        const uint32_t t = *state1 << 9;
        *state2 ^= *state0;
        *state3 ^= *state1;
        *state1 ^= *state2;
        *state0 ^= *state3;
        *state2 ^= t;
        *state3 = ((*state3 << 11) | (*state3 >> (32 - 11)));
      }
    }
  }
  *state0 = s0;
  *state1 = s1;
  *state2 = s2;
  *state3 = s3;
}


LIBXSMM_API_INLINE float internal_rng_scalar_float_next(int i)
{
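  /* Draws the next value of lane i: the state transition below follows a
   * xoshiro128+-style update; the upper 23 bits of (state0 + state3) form the
   * mantissa of a float in [1,2), and subtracting 1.0f yields a uniform
   * number in [0,1). */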
  const uint32_t rng_mantissa = (internal_rng_state0[i] + internal_rng_state3[i]) >> 9;
  const uint32_t t = internal_rng_state1[i] << 9;
  union { uint32_t i; float f; } rng;

  internal_rng_state2[i] ^= internal_rng_state0[i];
  internal_rng_state3[i] ^= internal_rng_state1[i];
  internal_rng_state1[i] ^= internal_rng_state2[i];
  internal_rng_state0[i] ^= internal_rng_state3[i];
  internal_rng_state2[i] ^= t;
  internal_rng_state3[i] = ((internal_rng_state3[i] << 11) | (internal_rng_state3[i] >> (32 - 11)));

  rng.i = 0x3f800000 | rng_mantissa;
  return rng.f - 1.0f;
}


LIBXSMM_API_INTERN void internal_rng_set_seed_sw(uint32_t seed);
LIBXSMM_API_INTERN void internal_rng_set_seed_sw(uint32_t seed)
{
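  /* Seeding scheme: each of the 16 lanes receives the user seed plus a distinct
   * per-word offset, and every lane is then advanced by one jump (2^64 draws),
   * presumably to move away from the strongly correlated initial values. */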
  static const uint32_t temp_state[] = {
     31,  30,  29,  28,  27,  26,  25,  24,  23,  22,  21,  20,  19,  18,  17,  16,
    131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116,
    231, 230, 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216,
    331, 330, 329, 328, 327, 326, 325, 324, 323, 322, 321, 320, 319, 318, 317, 316
  };
  libxsmm_blasint i;

  /* finish initializing the state */
  LIBXSMM_ASSERT((16 * 4) == sizeof(temp_state) / sizeof(*temp_state));
  for (i = 0; i < 16; ++i) {
    internal_rng_state0[i] = seed + temp_state[i];
    internal_rng_state1[i] = seed + temp_state[i+16];
    internal_rng_state2[i] = seed + temp_state[i+32];
    internal_rng_state3[i] = seed + temp_state[i+48];
  }
  for (i = 0; i < 16; ++i) {
    internal_rng_float_jump( /* progress each sequence by 2^64 */
      internal_rng_state0 + i, internal_rng_state1 + i,
      internal_rng_state2 + i, internal_rng_state3 + i);
  }
  /* for consistency, other RNGs are seeded as well */
#if !defined(_WIN32) && !defined(__CYGWIN__) && (defined(_SVID_SOURCE) || defined(_XOPEN_SOURCE))
  srand48(seed);
#endif
  srand(seed);
}


LIBXSMM_API_INLINE void internal_rng_f32_seq_sw(float* rngs, libxsmm_blasint count)
{
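  /* Generic (non-vectorized) sequence: draws one value at a time while cycling
   * through the 16 lanes in order (lane = i mod 16). */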
  libxsmm_blasint i = 0;
  for (; i < count; ++i) {
    rngs[i] = internal_rng_scalar_float_next(LIBXSMM_MOD2(i, 16));
  }
}


#if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
void internal_rng_set_seed_avx512(uint32_t seed)
{
  internal_rng_set_seed_sw(seed);
  /* bring scalar state to AVX-512 */
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_loadu_si512(internal_rng_state0);
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_loadu_si512(internal_rng_state1);
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_loadu_si512(internal_rng_state2);
  LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_loadu_si512(internal_rng_state3);
}

LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
void internal_rng_f32_seq_avx512(float* rngs, libxsmm_blasint count)
{
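  /* Vectorized sequence: produces 16 floats per iteration from the AVX-512 resident
   * state; any remainder is handled by temporarily spilling the state to the scalar
   * arrays, drawing scalar values, and loading the state back (see below). */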
  if ((LIBXSMM_RNG_SIMD_MIN << 4) <= count) { /* SIMD code path */
    const libxsmm_blasint n = (count >> 4) << 4; /* multiple of vector-length */
    libxsmm_blasint i = 0;
    for (; i < n; i += 16) {
      _mm512_storeu_ps(rngs + i, LIBXSMM_INTRINSICS_MM512_RNG_PS());
    }
    if (i < count) { /* remainder */
#if 0 /* assert(0 < n) */
      if (0 < n)
#endif
      { /* bring AVX-512 state to scalar */
        _mm512_storeu_si512(internal_rng_state0, LIBXSMM_INTRINSICS_MM512_RNG_STATE(0));
        _mm512_storeu_si512(internal_rng_state1, LIBXSMM_INTRINSICS_MM512_RNG_STATE(1));
        _mm512_storeu_si512(internal_rng_state2, LIBXSMM_INTRINSICS_MM512_RNG_STATE(2));
        _mm512_storeu_si512(internal_rng_state3, LIBXSMM_INTRINSICS_MM512_RNG_STATE(3));
      }
      LIBXSMM_ASSERT(count < i + 16);
      do { /* scalar remainder */
        rngs[i] = internal_rng_scalar_float_next(LIBXSMM_MOD2(i, 16));
        ++i;
      } while (i < count);
      /* bring scalar state to AVX-512 */
      LIBXSMM_INTRINSICS_MM512_RNG_STATE(0) = _mm512_loadu_si512(internal_rng_state0);
      LIBXSMM_INTRINSICS_MM512_RNG_STATE(1) = _mm512_loadu_si512(internal_rng_state1);
      LIBXSMM_INTRINSICS_MM512_RNG_STATE(2) = _mm512_loadu_si512(internal_rng_state2);
      LIBXSMM_INTRINSICS_MM512_RNG_STATE(3) = _mm512_loadu_si512(internal_rng_state3);
    }
  }
  else { /* scalar code path */
    internal_rng_f32_seq_sw(rngs, count);
  }
}
#endif /*defined(LIBXSMM_INTRINSICS_AVX512)*/


LIBXSMM_API unsigned int* libxsmm_rng_create_avx512_extstate(unsigned int/*uint32_t*/ seed)
{
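  /* External RNG state, e.g. for consumption by generated kernels (assumption based
   * on the function name): 64 consecutive 32-bit words, 64-byte aligned, laid out as
   * four blocks of 16 lanes (state0 in [0..15], state1 in [16..31], state2 in
   * [32..47], state3 in [48..63]); seeded like the library-internal state. */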
  unsigned int* state = (unsigned int*) libxsmm_aligned_malloc( 64*sizeof(unsigned int), 64 );
  static const uint32_t temp_state[] = {
     31,  30,  29,  28,  27,  26,  25,  24,  23,  22,  21,  20,  19,  18,  17,  16,
    131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116,
    231, 230, 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216,
    331, 330, 329, 328, 327, 326, 325, 324, 323, 322, 321, 320, 319, 318, 317, 316
  };
  libxsmm_blasint i;

  /* finish initializing the state */
  LIBXSMM_ASSERT((16 * 4) == sizeof(temp_state) / sizeof(*temp_state));
  for (i = 0; i < 16; ++i) {
    state[i   ] = seed + temp_state[i];
    state[i+16] = seed + temp_state[i+16];
    state[i+32] = seed + temp_state[i+32];
    state[i+48] = seed + temp_state[i+48];
  }
  for (i = 0; i < 16; ++i) {
    internal_rng_float_jump( /* progress each sequence by 2^64 */
      state +      i, state + 16 + i,
      state + 32 + i, state + 48 + i);
  }

  return state;
}


LIBXSMM_API void libxsmm_rng_destroy_avx512_extstate(unsigned int* stateptr)
{
  if ( stateptr != NULL ) {
    libxsmm_free( stateptr );
  }
}
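/* Usage sketch for the external state (illustrative only):
 *   unsigned int* rng_state = libxsmm_rng_create_avx512_extstate(seed);
 *   ...hand rng_state to code that consumes the external RNG state...
 *   libxsmm_rng_destroy_avx512_extstate(rng_state);
 */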


LIBXSMM_API void libxsmm_rng_set_seed(unsigned int/*uint32_t*/ seed)
{
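  /* Dispatch: with a static AVX-512 target the vectorized seeding is used directly;
   * otherwise the code path is selected at runtime via libxsmm_target_archid, and
   * the scalar implementation serves as fallback. */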
  LIBXSMM_INIT
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
# if !defined(NDEBUG) /* used to track if seed is initialized */
  internal_rng_f32_seq = internal_rng_f32_seq_avx512;
# endif
  internal_rng_set_seed_avx512(seed);
#elif defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  if (LIBXSMM_X86_AVX512 <= libxsmm_target_archid) {
    internal_rng_f32_seq = internal_rng_f32_seq_avx512;
    internal_rng_set_seed_avx512(seed);
  }
  else {
    internal_rng_f32_seq = internal_rng_f32_seq_sw;
    internal_rng_set_seed_sw(seed);
  }
#else
# if !defined(NDEBUG) /* used to track if seed is initialized */
  internal_rng_f32_seq = internal_rng_f32_seq_sw;
# endif
  internal_rng_set_seed_sw(seed);
#endif
}


LIBXSMM_API void libxsmm_rng_f32_seq(float* rngs, libxsmm_blasint count)
{
  LIBXSMM_ASSERT_MSG(NULL != internal_rng_f32_seq, "RNG must be initialized");
#if (LIBXSMM_X86_AVX512 <= LIBXSMM_STATIC_TARGET_ARCH)
  internal_rng_f32_seq_avx512(rngs, count);
#else
# if defined(LIBXSMM_INTRINSICS_AVX512) /* __AVX512F__ */
  if ((LIBXSMM_RNG_SIMD_MIN << 4) <= count) { /* SIMD code path */
    internal_rng_f32_seq(rngs, count); /* pointer based function call */
  }
  else /* scalar code path */
# endif
  internal_rng_f32_seq_sw(rngs, count);
#endif
}
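/* Usage sketch (illustrative only): seed once, then draw batches of uniform
 * single-precision numbers in [0,1):
 *   float buffer[256];
 *   libxsmm_rng_set_seed(12345);
 *   libxsmm_rng_f32_seq(buffer, 256);
 */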


LIBXSMM_API unsigned int libxsmm_rng_u32(unsigned int n)
{
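  /* Returns a pseudo-random number in [0,n) based on the libc generator; q is the
   * largest multiple of n within the generator's range, and draws at or above q are
   * rejected to avoid modulo bias (n is expected to be non-zero). */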
#if defined(LIBXSMM_RNG_DRAND48)
  const unsigned int q = ((1U << 31) / n) * n;
  unsigned int r = (unsigned int)lrand48();
  if (q != (1U << 31))
#else
  const unsigned int rand_max1 = (unsigned int)(RAND_MAX)+1U;
  const unsigned int q = (rand_max1 / n) * n;
  unsigned int r = (unsigned int)rand();
  if (q != rand_max1)
#endif
  {
#if defined(LIBXSMM_RNG_DRAND48)
    /* coverity[dont_call] */
    while (q <= r) r = (unsigned int)lrand48();
#else
    while (q <= r) r = (unsigned int)rand();
#endif
  }
  return r % n;
}


LIBXSMM_API void libxsmm_rng_seq(void* data, libxsmm_blasint nbytes)
{
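  /* Fills the destination with pseudo-random bytes: whole 4-byte words first
   * (nbytes rounded down to a multiple of four), then a partial copy for the tail. */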
  unsigned char* dst = (unsigned char*)data;
  unsigned char* end = dst + (nbytes & 0xFFFFFFFFFFFFFFFC);
  unsigned int r;
  for (; dst < end; dst += 4) {
#if defined(LIBXSMM_RNG_DRAND48)
    /* coverity[dont_call] */
    r = (unsigned int)lrand48();
#else
    r = (unsigned int)rand();
#endif
    LIBXSMM_MEMCPY127(dst, &r, 4);
  }
  end = (unsigned char*)data + nbytes;
  if (dst < end) {
#if defined(LIBXSMM_RNG_DRAND48)
    r = (unsigned int)lrand48();
#else
    r = (unsigned int)rand();
#endif
    LIBXSMM_MEMCPY127(dst, &r, end - dst);
  }
}


LIBXSMM_API double libxsmm_rng_f64(void)
{
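  /* drand48() returns values in [0,1); the rand()-based fallback scales by
   * 1/RAND_MAX and therefore covers (approximately) the closed interval [0,1]. */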
#if defined(LIBXSMM_RNG_DRAND48)
  /* coverity[dont_call] */
  return drand48();
#else
  static const double scale = 1.0 / (RAND_MAX);
  return scale * (double)rand();
#endif
}