1 /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
3  *
4  * Written by Nir Drucker, Shay Gueron and Dusan Kostic,
5  * AWS Cryptographic Algorithms Group.
6  */
7 
8 #include <assert.h>
9 
10 #include "sampling.h"
11 #include "sampling_internal.h"
12 
// The SIMD implementation of is_new requires the wlist buffer it scans to hold
// a whole multiple of the number of DWORDs in one SIMD register. An AVX512
// register is 512 bits wide, i.e. 16 DWORDs. is_new is used both when
// sampling the DV positions of a sparse vector and when sampling the T1 error
// positions, so a padded size is provided for each count.
#define AVX512_REG_DWORDS (16)
// Rounds x up to the next multiple of AVX512_REG_DWORDS.
#define ROUND_UP_TO_AVX512_REG_DWORDS(x) \
  (AVX512_REG_DWORDS * DIVIDE_AND_CEIL((x), AVX512_REG_DWORDS))
#define WLIST_SIZE_ADJUSTED_D ROUND_UP_TO_AVX512_REG_DWORDS(DV)
#define WLIST_SIZE_ADJUSTED_T ROUND_UP_TO_AVX512_REG_DWORDS(T1)
22 
23 // BSR returns ceil(log2(val))
bit_scan_reverse_vartime(IN uint64_t val)24 _INLINE_ uint8_t bit_scan_reverse_vartime(IN uint64_t val)
25 {
26   // index is always smaller than 64
27   uint8_t index = 0;
28 
29   while(val != 0) {
30     val >>= 1;
31     index++;
32   }
33 
34   return index;
35 }
36 
get_rand_mod_len(OUT uint32_t * rand_pos,IN const uint32_t len,IN OUT aes_ctr_prf_state_t * prf_state)37 _INLINE_ ret_t get_rand_mod_len(OUT uint32_t *    rand_pos,
38                                 IN const uint32_t len,
39                                 IN OUT aes_ctr_prf_state_t *prf_state)
40 {
41   const uint64_t mask = MASK(bit_scan_reverse_vartime(len));
42 
43   do {
44     // Generate a 32 bits (pseudo) random value.
45     // This can be optimized to take only 16 bits.
46     POSIX_GUARD(aes_ctr_prf((uint8_t *)rand_pos, prf_state, sizeof(*rand_pos)));
47 
48     // Mask relevant bits only
49     (*rand_pos) &= mask;
50 
51     // Break if a number that is smaller than len is found
52     if((*rand_pos) < len) {
53       break;
54     }
55 
56   } while(1 == 1);
57 
58   return SUCCESS;
59 }
60 
make_odd_weight(IN OUT r_t * r)61 _INLINE_ void make_odd_weight(IN OUT r_t *r)
62 {
63   if(((r_bits_vector_weight(r) % 2) == 1)) {
64     // Already odd
65     return;
66   }
67 
68   r->raw[0] ^= 1;
69 }
70 
71 // Returns an array of r pseudorandom bits.
72 // No restrictions exist for the top or bottom bits.
73 // If the generation requires an odd number, then set must_be_odd=1.
74 // The function uses the provided prf context.
sample_uniform_r_bits_with_fixed_prf_context(OUT r_t * r,IN OUT aes_ctr_prf_state_t * prf_state,IN const must_be_odd_t must_be_odd)75 ret_t sample_uniform_r_bits_with_fixed_prf_context(
76   OUT r_t *r,
77   IN OUT aes_ctr_prf_state_t *prf_state,
78   IN const must_be_odd_t      must_be_odd)
79 {
80   // Generate random data
81   POSIX_GUARD(aes_ctr_prf(r->raw, prf_state, R_BYTES));
82 
83   // Mask upper bits of the MSByte
84   r->raw[R_BYTES - 1] &= MASK(R_BITS + 8 - (R_BYTES * 8));
85 
86   if(must_be_odd == MUST_BE_ODD) {
87     make_odd_weight(r);
88   }
89 
90   return SUCCESS;
91 }
92 
generate_indices_mod_z(OUT idx_t * out,IN const size_t num_indices,IN const size_t z,IN OUT aes_ctr_prf_state_t * prf_state,IN const sampling_ctx * ctx)93 ret_t generate_indices_mod_z(OUT idx_t *     out,
94                              IN const size_t num_indices,
95                              IN const size_t z,
96                              IN OUT aes_ctr_prf_state_t *prf_state,
97                              IN const sampling_ctx *ctx)
98 {
99   size_t ctr = 0;
100 
101   // Generate num_indices unique (pseudo) random numbers modulo z
102   do {
103     POSIX_GUARD(get_rand_mod_len(&out[ctr], z, prf_state));
104     ctr += ctx->is_new(out, ctr);
105   } while(ctr < num_indices);
106 
107   return SUCCESS;
108 }
109 
110 // Returns an array of r pseudorandom bits.
111 // No restrictions exist for the top or bottom bits.
112 // If the generation requires an odd number, then set must_be_odd = MUST_BE_ODD
sample_uniform_r_bits(OUT r_t * r,IN const seed_t * seed,IN const must_be_odd_t must_be_odd)113 ret_t sample_uniform_r_bits(OUT r_t *r,
114                             IN const seed_t *      seed,
115                             IN const must_be_odd_t must_be_odd)
116 {
117   // For the seedexpander
118   DEFER_CLEANUP(aes_ctr_prf_state_t prf_state = {0}, aes_ctr_prf_state_cleanup);
119 
120   POSIX_GUARD(init_aes_ctr_prf_state(&prf_state, MAX_AES_INVOKATION, seed));
121 
122   POSIX_GUARD(sample_uniform_r_bits_with_fixed_prf_context(r, &prf_state, must_be_odd));
123 
124   return SUCCESS;
125 }
126 
generate_sparse_rep(OUT pad_r_t * r,OUT idx_t * wlist,IN OUT aes_ctr_prf_state_t * prf_state)127 ret_t generate_sparse_rep(OUT pad_r_t *r,
128                           OUT idx_t *wlist,
129                           IN OUT aes_ctr_prf_state_t *prf_state)
130 {
131 
132   // Initialize the sampling context
133   sampling_ctx ctx;
134   sampling_ctx_init(&ctx);
135 
136   idx_t wlist_temp[WLIST_SIZE_ADJUSTED_D] = {0};
137 
138   POSIX_GUARD(generate_indices_mod_z(wlist_temp, DV, R_BITS, prf_state, &ctx));
139 
140   bike_memcpy(wlist, wlist_temp, DV * sizeof(idx_t));
141   ctx.secure_set_bits(r, 0, wlist, DV);
142 
143   return SUCCESS;
144 }
145 
generate_error_vector(OUT pad_e_t * e,IN const seed_t * seed)146 ret_t generate_error_vector(OUT pad_e_t *e, IN const seed_t *seed)
147 {
148   DEFER_CLEANUP(aes_ctr_prf_state_t prf_state = {0}, aes_ctr_prf_state_cleanup);
149 
150   POSIX_GUARD(init_aes_ctr_prf_state(&prf_state, MAX_AES_INVOKATION, seed));
151 
152   // Initialize the sampling context
153   sampling_ctx ctx;
154   sampling_ctx_init(&ctx);
155 
156   idx_t wlist[WLIST_SIZE_ADJUSTED_T] = {0};
157   POSIX_GUARD(generate_indices_mod_z(wlist, T1, N_BITS, &prf_state, &ctx));
158 
159   // (e0, e1) hold bits 0..R_BITS-1 and R_BITS..2*R_BITS-1 of the error, resp.
160   ctx.secure_set_bits(&e->val[0], 0, wlist, T1);
161   ctx.secure_set_bits(&e->val[1], R_BITS, wlist, T1);
162 
163   // Clean the padding of the elements
164   PE0_RAW(e)[R_BYTES - 1] &= LAST_R_BYTE_MASK;
165   PE1_RAW(e)[R_BYTES - 1] &= LAST_R_BYTE_MASK;
166   bike_memset(&PE0_RAW(e)[R_BYTES], 0, R_PADDED_BYTES - R_BYTES);
167   bike_memset(&PE1_RAW(e)[R_BYTES], 0, R_PADDED_BYTES - R_BYTES);
168 
169   return SUCCESS;
170 }
171