1 #include <stdint.h>
2 #include <string.h>
3 
4 #include "address.h"
5 #include "hash.h"
6 #include "hash_state.h"
7 #include "hashx4.h"
8 #include "params.h"
9 #include "thash.h"
10 #include "thashx4.h"
11 #include "utils.h"
12 #include "wots.h"
13 
14 // TODO clarify address expectations, and make them more uniform.
15 // TODO i.e. do we expect types to be set already?
16 // TODO and do we expect modifications or copies?
17 
18 /**
19  * Computes the starting value for a chain, i.e. the secret key.
20  * Expects the address to be complete up to the chain address.
21  */
wots_gen_sk(unsigned char * sk,const unsigned char * sk_seed,uint32_t wots_addr[8],const hash_state * state_seeded)22 static void wots_gen_sk(unsigned char *sk, const unsigned char *sk_seed,
23                         uint32_t wots_addr[8], const hash_state *state_seeded) {
24     /* Make sure that the hash address is actually zeroed. */
25     PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_set_hash_addr(wots_addr, 0);
26 
27     /* Generate sk element. */
28     PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_prf_addr(sk, sk_seed, wots_addr, state_seeded);
29 }
30 
31 /**
32  * 4-way parallel version of wots_gen_sk; expects 4x as much space in sk
33  */
wots_gen_skx4(unsigned char * skx4,const unsigned char * sk_seed,uint32_t wots_addrx4[4* 8],const hash_state * state_seeded)34 static void wots_gen_skx4(unsigned char *skx4, const unsigned char *sk_seed,
35                           uint32_t wots_addrx4[4 * 8], const hash_state *state_seeded) {
36     unsigned int j;
37 
38     /* Make sure that the hash address is actually zeroed. */
39     for (j = 0; j < 4; j++) {
40         PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_set_hash_addr(wots_addrx4 + j * 8, 0);
41     }
42 
43     /* Generate sk element. */
44     PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_prf_addrx4(skx4 + 0 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
45             skx4 + 1 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
46             skx4 + 2 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
47             skx4 + 3 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
48             sk_seed, wots_addrx4,
49             state_seeded);
50 }
51 
52 /**
53  * Computes the chaining function.
54  * out and in have to be n-byte arrays.
55  *
56  * Interprets in as start-th value of the chain.
57  * addr has to contain the address of the chain.
58  */
gen_chain(unsigned char * out,const unsigned char * in,unsigned int start,unsigned int steps,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)59 static void gen_chain(unsigned char *out, const unsigned char *in,
60                       unsigned int start, unsigned int steps,
61                       const unsigned char *pub_seed, uint32_t addr[8],
62                       const hash_state *state_seeded) {
63     uint32_t i;
64 
65     /* Initialize out with the value at position 'start'. */
66     memcpy(out, in, PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N);
67 
68     /* Iterate 'steps' calls to the hash function. */
69     for (i = start; i < (start + steps) && i < PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_W; i++) {
70         PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_set_hash_addr(addr, i);
71         PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_thash_1(out, out, pub_seed, addr, state_seeded);
72     }
73 }
74 
75 /**
76  * 4-way parallel version of gen_chain; expects 4x as much space in out, and
77  * 4x as much space in inx4. Assumes start and step identical across chains.
78  */
gen_chainx4(unsigned char * outx4,const unsigned char * inx4,unsigned int start,unsigned int steps,const unsigned char * pub_seed,uint32_t addrx4[4* 8],const hash_state * state_seeded)79 static void gen_chainx4(unsigned char *outx4, const unsigned char *inx4,
80                         unsigned int start, unsigned int steps,
81                         const unsigned char *pub_seed, uint32_t addrx4[4 * 8],
82                         const hash_state *state_seeded) {
83     uint32_t i;
84     unsigned int j;
85 
86     /* Initialize outx4 with the value at position 'start'. */
87     memcpy(outx4, inx4, 4 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N);
88 
89     /* Iterate 'steps' calls to the hash function. */
90     for (i = start; i < (start + steps) && i < PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_W; i++) {
91         for (j = 0; j < 4; j++) {
92             PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_set_hash_addr(addrx4 + j * 8, i);
93         }
94         PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_thashx4_1(outx4 + 0 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
95                 outx4 + 1 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
96                 outx4 + 2 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
97                 outx4 + 3 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
98                 outx4 + 0 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
99                 outx4 + 1 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
100                 outx4 + 2 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
101                 outx4 + 3 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
102                 pub_seed, addrx4,
103                 state_seeded);
104     }
105 }
106 
107 /**
108  * base_w algorithm as described in draft.
109  * Interprets an array of bytes as integers in base w.
110  * This only works when log_w is a divisor of 8.
111  */
base_w(unsigned int * output,const int out_len,const unsigned char * input)112 static void base_w(unsigned int *output, const int out_len, const unsigned char *input) {
113     int in = 0;
114     int out = 0;
115     unsigned char total = 0;
116     int bits = 0;
117     int consumed;
118 
119     for (consumed = 0; consumed < out_len; consumed++) {
120         if (bits == 0) {
121             total = input[in];
122             in++;
123             bits += 8;
124         }
125         bits -= PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LOGW;
126         output[out] = (unsigned int)(total >> bits) & (PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_W - 1);
127         out++;
128     }
129 }
130 
131 /* Computes the WOTS+ checksum over a message (in base_w). */
wots_checksum(unsigned int * csum_base_w,const unsigned int * msg_base_w)132 static void wots_checksum(unsigned int *csum_base_w, const unsigned int *msg_base_w) {
133     unsigned int csum = 0;
134     unsigned char csum_bytes[(PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN2 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LOGW + 7) / 8];
135     unsigned int i;
136 
137     /* Compute checksum. */
138     for (i = 0; i < PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN1; i++) {
139         csum += PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_W - 1 - msg_base_w[i];
140     }
141 
142     /* Convert checksum to base_w. */
143     /* Make sure expected empty zero bits are the least significant bits. */
144     csum = csum << (8 - ((PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN2 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LOGW) % 8));
145     PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_ull_to_bytes(csum_bytes, sizeof(csum_bytes), csum);
146     base_w(csum_base_w, PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN2, csum_bytes);
147 }
148 
149 /* Takes a message and derives the matching chain lengths. */
chain_lengths(unsigned int * lengths,const unsigned char * msg)150 static void chain_lengths(unsigned int *lengths, const unsigned char *msg) {
151     base_w(lengths, PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN1, msg);
152     wots_checksum(lengths + PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN1, lengths);
153 }
154 
155 /**
156  * WOTS key generation. Takes a 32 byte sk_seed, expands it to WOTS private key
157  * elements and computes the corresponding public key.
158  * It requires the seed pub_seed (used to generate bitmasks and hash keys)
159  * and the address of this WOTS key pair.
160  *
161  * Writes the computed public key to 'pk'.
162  */
PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_wots_gen_pk(unsigned char * pk,const unsigned char * sk_seed,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)163 void PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_wots_gen_pk(unsigned char *pk, const unsigned char *sk_seed,
164         const unsigned char *pub_seed, uint32_t addr[8],
165         const hash_state *state_seeded) {
166     uint32_t i;
167     unsigned int j;
168 
169     uint32_t addrx4[4 * 8];
170     unsigned char pkbuf[4 * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N];
171 
172     for (j = 0; j < 4; j++) {
173         memcpy(addrx4 + j * 8, addr, sizeof(uint32_t) * 8);
174     }
175 
176     /* The last iteration typically does not have complete set of 4 chains,
177        but because we use pkbuf, this is not an issue -- we still do as many
178        in parallel as possible. */
179     for (i = 0; i < ((PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN + 3) & ~0x3); i += 4) {
180         for (j = 0; j < 4; j++) {
181             PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_set_chain_addr(addrx4 + j * 8, i + j);
182         }
183         wots_gen_skx4(pkbuf, sk_seed, addrx4, state_seeded);
184         gen_chainx4(pkbuf, pkbuf, 0, PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_W - 1, pub_seed, addrx4, state_seeded);
185         for (j = 0; j < 4; j++) {
186             if (i + j < PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN) {
187                 memcpy(pk + (i + j)*PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N, pkbuf + j * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N, PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N);
188             }
189         }
190     }
191 
192     // Get rid of unused argument variable.
193     (void)state_seeded;
194 }
195 
196 /**
197  * Takes a n-byte message and the 32-byte sk_see to compute a signature 'sig'.
198  */
PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_wots_sign(unsigned char * sig,const unsigned char * msg,const unsigned char * sk_seed,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)199 void PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_wots_sign(unsigned char *sig, const unsigned char *msg,
200         const unsigned char *sk_seed, const unsigned char *pub_seed,
201         uint32_t addr[8], const hash_state *state_seeded) {
202     unsigned int lengths[PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN];
203     uint32_t i;
204 
205     chain_lengths(lengths, msg);
206 
207     for (i = 0; i < PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN; i++) {
208         PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_set_chain_addr(addr, i);
209         wots_gen_sk(sig + i * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N, sk_seed, addr, state_seeded);
210         gen_chain(sig + i * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N, sig + i * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N, 0, lengths[i], pub_seed, addr, state_seeded);
211     }
212 
213     // avoid unused argument
214     (void)state_seeded;
215 }
216 
217 /**
218  * Takes a WOTS signature and an n-byte message, computes a WOTS public key.
219  *
220  * Writes the computed public key to 'pk'.
221  */
PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_wots_pk_from_sig(unsigned char * pk,const unsigned char * sig,const unsigned char * msg,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)222 void PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_wots_pk_from_sig(unsigned char *pk,
223         const unsigned char *sig, const unsigned char *msg,
224         const unsigned char *pub_seed, uint32_t addr[8],
225         const hash_state *state_seeded) {
226     unsigned int lengths[PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN];
227     uint32_t i;
228 
229     chain_lengths(lengths, msg);
230 
231     for (i = 0; i < PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_LEN; i++) {
232         PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_set_chain_addr(addr, i);
233         gen_chain(pk + i * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N, sig + i * PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_N,
234                   lengths[i], PQCLEAN_SPHINCSSHAKE256192FROBUST_AVX2_WOTS_W - 1 - lengths[i], pub_seed, addr,
235                   state_seeded);
236     }
237 
238     // avoid unused argument
239     (void)state_seeded;
240 }
241