1 #include <stdint.h>
2 #include <string.h>
3 
4 #include "address.h"
5 #include "hash.h"
6 #include "hash_state.h"
7 #include "hashx8.h"
8 #include "params.h"
9 #include "thash.h"
10 #include "thashx8.h"
11 #include "utils.h"
12 #include "wots.h"
13 
14 // TODO clarify address expectations, and make them more uniform.
15 // TODO i.e. do we expect types to be set already?
16 // TODO and do we expect modifications or copies?
17 
18 /**
19  * Computes the starting value for a chain, i.e. the secret key.
20  * Expects the address to be complete up to the chain address.
21  */
wots_gen_sk(unsigned char * sk,const unsigned char * sk_seed,uint32_t wots_addr[8],const hash_state * state_seeded)22 static void wots_gen_sk(unsigned char *sk, const unsigned char *sk_seed,
23                         uint32_t wots_addr[8], const hash_state *state_seeded) {
24     /* Make sure that the hash address is actually zeroed. */
25     PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_set_hash_addr(wots_addr, 0);
26 
27     /* Generate sk element. */
28     PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_prf_addr(sk, sk_seed, wots_addr, state_seeded);
29 }
30 
31 /**
32  * 8-way parallel version of wots_gen_sk; expects 8x as much space in sk
33  */
wots_gen_skx8(unsigned char * skx8,const unsigned char * sk_seed,uint32_t wots_addrx8[8* 8])34 static void wots_gen_skx8(unsigned char *skx8, const unsigned char *sk_seed,
35                           uint32_t wots_addrx8[8 * 8]) {
36     unsigned int j;
37 
38     /* Make sure that the hash address is actually zeroed. */
39     for (j = 0; j < 8; j++) {
40         PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_set_hash_addr(wots_addrx8 + j * 8, 0);
41     }
42 
43     /* Generate sk element. */
44     PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_prf_addrx8(skx8 + 0 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
45             skx8 + 1 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
46             skx8 + 2 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
47             skx8 + 3 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
48             skx8 + 4 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
49             skx8 + 5 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
50             skx8 + 6 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
51             skx8 + 7 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
52             sk_seed, wots_addrx8);
53 }
54 
55 /**
56  * Computes the chaining function.
57  * out and in have to be n-byte arrays.
58  *
59  * Interprets in as start-th value of the chain.
60  * addr has to contain the address of the chain.
61  */
gen_chain(unsigned char * out,const unsigned char * in,unsigned int start,unsigned int steps,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)62 static void gen_chain(unsigned char *out, const unsigned char *in,
63                       unsigned int start, unsigned int steps,
64                       const unsigned char *pub_seed, uint32_t addr[8],
65                       const hash_state *state_seeded) {
66     uint32_t i;
67 
68     /* Initialize out with the value at position 'start'. */
69     memcpy(out, in, PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N);
70 
71     /* Iterate 'steps' calls to the hash function. */
72     for (i = start; i < (start + steps) && i < PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_W; i++) {
73         PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_set_hash_addr(addr, i);
74         PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_thash_1(out, out, pub_seed, addr, state_seeded);
75     }
76 }
77 
78 /**
79  * 8-way parallel version of gen_chain; expects 8x as much space in out, and
80  * 8x as much space in inx8. Assumes start and step identical across chains.
81  */
gen_chainx8(unsigned char * outx8,const unsigned char * inx8,unsigned int start,unsigned int steps,const unsigned char * pub_seed,uint32_t addrx8[8* 8],const hash_state * state_seeded)82 static void gen_chainx8(unsigned char *outx8, const unsigned char *inx8,
83                         unsigned int start, unsigned int steps,
84                         const unsigned char *pub_seed, uint32_t addrx8[8 * 8],
85                         const hash_state *state_seeded) {
86     uint32_t i;
87     unsigned int j;
88 
89     /* Initialize outx8 with the value at position 'start'. */
90     memcpy(outx8, inx8, 8 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N);
91 
92     /* Iterate 'steps' calls to the hash function. */
93     for (i = start; i < (start + steps) && i < PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_W; i++) {
94         for (j = 0; j < 8; j++) {
95             PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_set_hash_addr(addrx8 + j * 8, i);
96         }
97         PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_thashx8_1(outx8 + 0 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
98                 outx8 + 1 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
99                 outx8 + 2 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
100                 outx8 + 3 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
101                 outx8 + 4 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
102                 outx8 + 5 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
103                 outx8 + 6 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
104                 outx8 + 7 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
105                 outx8 + 0 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
106                 outx8 + 1 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
107                 outx8 + 2 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
108                 outx8 + 3 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
109                 outx8 + 4 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
110                 outx8 + 5 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
111                 outx8 + 6 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
112                 outx8 + 7 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
113                 pub_seed, addrx8, state_seeded);
114     }
115 }
116 
117 /**
118  * base_w algorithm as described in draft.
119  * Interprets an array of bytes as integers in base w.
120  * This only works when log_w is a divisor of 8.
121  */
base_w(unsigned int * output,const int out_len,const unsigned char * input)122 static void base_w(unsigned int *output, const int out_len, const unsigned char *input) {
123     int in = 0;
124     int out = 0;
125     unsigned char total = 0;
126     int bits = 0;
127     int consumed;
128 
129     for (consumed = 0; consumed < out_len; consumed++) {
130         if (bits == 0) {
131             total = input[in];
132             in++;
133             bits += 8;
134         }
135         bits -= PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LOGW;
136         output[out] = (unsigned int)(total >> bits) & (PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_W - 1);
137         out++;
138     }
139 }
140 
141 /* Computes the WOTS+ checksum over a message (in base_w). */
wots_checksum(unsigned int * csum_base_w,const unsigned int * msg_base_w)142 static void wots_checksum(unsigned int *csum_base_w, const unsigned int *msg_base_w) {
143     unsigned int csum = 0;
144     unsigned char csum_bytes[(PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN2 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LOGW + 7) / 8];
145     unsigned int i;
146 
147     /* Compute checksum. */
148     for (i = 0; i < PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN1; i++) {
149         csum += PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_W - 1 - msg_base_w[i];
150     }
151 
152     /* Convert checksum to base_w. */
153     /* Make sure expected empty zero bits are the least significant bits. */
154     csum = csum << (8 - ((PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN2 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LOGW) % 8));
155     PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_ull_to_bytes(csum_bytes, sizeof(csum_bytes), csum);
156     base_w(csum_base_w, PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN2, csum_bytes);
157 }
158 
159 /* Takes a message and derives the matching chain lengths. */
chain_lengths(unsigned int * lengths,const unsigned char * msg)160 static void chain_lengths(unsigned int *lengths, const unsigned char *msg) {
161     base_w(lengths, PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN1, msg);
162     wots_checksum(lengths + PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN1, lengths);
163 }
164 
165 /**
166  * WOTS key generation. Takes a 32 byte sk_seed, expands it to WOTS private key
167  * elements and computes the corresponding public key.
168  * It requires the seed pub_seed (used to generate bitmasks and hash keys)
169  * and the address of this WOTS key pair.
170  *
171  * Writes the computed public key to 'pk'.
172  */
PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_wots_gen_pk(unsigned char * pk,const unsigned char * sk_seed,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)173 void PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_wots_gen_pk(unsigned char *pk, const unsigned char *sk_seed,
174         const unsigned char *pub_seed, uint32_t addr[8],
175         const hash_state *state_seeded) {
176     uint32_t i;
177     unsigned int j;
178 
179     uint32_t addrx8[8 * 8];
180     unsigned char pkbuf[8 * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N];
181 
182     for (j = 0; j < 8; j++) {
183         memcpy(addrx8 + j * 8, addr, sizeof(uint32_t) * 8);
184     }
185 
186     /* The last iteration typically does not have complete set of 4 chains,
187        but because we use pkbuf, this is not an issue -- we still do as many
188        in parallel as possible. */
189     for (i = 0; i < ((PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN + 7) & ~0x7); i += 8) {
190         for (j = 0; j < 8; j++) {
191             PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_set_chain_addr(addrx8 + j * 8, i + j);
192         }
193         wots_gen_skx8(pkbuf, sk_seed, addrx8);
194         gen_chainx8(pkbuf, pkbuf, 0, PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_W - 1, pub_seed, addrx8, state_seeded);
195         for (j = 0; j < 8; j++) {
196             if (i + j < PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN) {
197                 memcpy(pk + (i + j)*PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N, pkbuf + j * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N, PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N);
198             }
199         }
200     }
201 }
202 
203 /**
204  * Takes a n-byte message and the 32-byte sk_see to compute a signature 'sig'.
205  */
PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_wots_sign(unsigned char * sig,const unsigned char * msg,const unsigned char * sk_seed,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)206 void PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_wots_sign(unsigned char *sig, const unsigned char *msg,
207         const unsigned char *sk_seed, const unsigned char *pub_seed,
208         uint32_t addr[8], const hash_state *state_seeded) {
209     unsigned int lengths[PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN];
210     uint32_t i;
211 
212     chain_lengths(lengths, msg);
213 
214     for (i = 0; i < PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN; i++) {
215         PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_set_chain_addr(addr, i);
216         wots_gen_sk(sig + i * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N, sk_seed, addr, state_seeded);
217         gen_chain(sig + i * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N, sig + i * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N, 0, lengths[i], pub_seed, addr, state_seeded);
218     }
219 }
220 
221 /**
222  * Takes a WOTS signature and an n-byte message, computes a WOTS public key.
223  *
224  * Writes the computed public key to 'pk'.
225  */
PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_wots_pk_from_sig(unsigned char * pk,const unsigned char * sig,const unsigned char * msg,const unsigned char * pub_seed,uint32_t addr[8],const hash_state * state_seeded)226 void PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_wots_pk_from_sig(unsigned char *pk,
227         const unsigned char *sig, const unsigned char *msg,
228         const unsigned char *pub_seed, uint32_t addr[8],
229         const hash_state *state_seeded) {
230     unsigned int lengths[PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN];
231     uint32_t i;
232 
233     chain_lengths(lengths, msg);
234 
235     for (i = 0; i < PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_LEN; i++) {
236         PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_set_chain_addr(addr, i);
237         gen_chain(pk + i * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N, sig + i * PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_N,
238                   lengths[i], PQCLEAN_SPHINCSSHA256192FSIMPLE_AVX2_WOTS_W - 1 - lengths[i], pub_seed, addr, state_seeded);
239     }
240 }
241