1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
5 /*
6  * Derived from public domain code by Matthew Dempsky and D. J. Bernstein.
7  */
8 
9 #include "ecl-priv.h"
10 #include "mpi.h"
11 
12 #include <stdint.h>
13 #include <stdio.h>
14 
15 typedef uint32_t elem[32];
16 
17 /*
18  * Add two field elements.
19  * out = a + b
20  */
21 static void
add(elem out,const elem a,const elem b)22 add(elem out, const elem a, const elem b)
23 {
24     uint32_t j;
25     uint32_t u = 0;
26     for (j = 0; j < 31; ++j) {
27         u += a[j] + b[j];
28         out[j] = u & 0xFF;
29         u >>= 8;
30     }
31     u += a[31] + b[31];
32     out[31] = u;
33 }
34 
35 /*
36  * Subtract two field elements.
37  * out = a - b
38  */
39 static void
sub(elem out,const elem a,const elem b)40 sub(elem out, const elem a, const elem b)
41 {
42     uint32_t j;
43     uint32_t u;
44     u = 218;
45     for (j = 0; j < 31; ++j) {
46         u += a[j] + 0xFF00 - b[j];
47         out[j] = u & 0xFF;
48         u >>= 8;
49     }
50     u += a[31] - b[31];
51     out[31] = u;
52 }
53 
54 /*
55  * "Squeeze" an element after multiplication (and square).
56  */
57 static void
squeeze(elem a)58 squeeze(elem a)
59 {
60     uint32_t j;
61     uint32_t u;
62     u = 0;
63     for (j = 0; j < 31; ++j) {
64         u += a[j];
65         a[j] = u & 0xFF;
66         u >>= 8;
67     }
68     u += a[31];
69     a[31] = u & 0x7F;
70     u = 19 * (u >> 7);
71     for (j = 0; j < 31; ++j) {
72         u += a[j];
73         a[j] = u & 0xFF;
74         u >>= 8;
75     }
76     a[31] += u;
77 }
78 
79 static const elem minusp = { 19, 0, 0, 0, 0, 0, 0, 0,
80                              0, 0, 0, 0, 0, 0, 0, 0,
81                              0, 0, 0, 0, 0, 0, 0, 0,
82                              0, 0, 0, 0, 0, 0, 0, 128 };
83 
84 /*
85  * Reduce point a by 2^255-19
86  */
87 static void
reduce(elem a)88 reduce(elem a)
89 {
90     elem aorig;
91     uint32_t j;
92     uint32_t negative;
93 
94     for (j = 0; j < 32; ++j) {
95         aorig[j] = a[j];
96     }
97     add(a, a, minusp);
98     negative = 1 + ~((a[31] >> 7) & 1);
99     for (j = 0; j < 32; ++j) {
100         a[j] ^= negative & (aorig[j] ^ a[j]);
101     }
102 }
103 
104 /*
105  * Multiplication and squeeze
106  * out = a * b
107  */
108 static void
mult(elem out,const elem a,const elem b)109 mult(elem out, const elem a, const elem b)
110 {
111     uint32_t i;
112     uint32_t j;
113     uint32_t u;
114 
115     for (i = 0; i < 32; ++i) {
116         u = 0;
117         for (j = 0; j <= i; ++j) {
118             u += a[j] * b[i - j];
119         }
120         for (j = i + 1; j < 32; ++j) {
121             u += 38 * a[j] * b[i + 32 - j];
122         }
123         out[i] = u;
124     }
125     squeeze(out);
126 }
127 
128 /*
129  * Multiplication
130  * out = 121665 * a
131  */
132 static void
mult121665(elem out,const elem a)133 mult121665(elem out, const elem a)
134 {
135     uint32_t j;
136     uint32_t u;
137 
138     u = 0;
139     for (j = 0; j < 31; ++j) {
140         u += 121665 * a[j];
141         out[j] = u & 0xFF;
142         u >>= 8;
143     }
144     u += 121665 * a[31];
145     out[31] = u & 0x7F;
146     u = 19 * (u >> 7);
147     for (j = 0; j < 31; ++j) {
148         u += out[j];
149         out[j] = u & 0xFF;
150         u >>= 8;
151     }
152     u += out[j];
153     out[j] = u;
154 }
155 
156 /*
157  * Square a and squeeze the result.
158  * out = a * a
159  */
160 static void
square(elem out,const elem a)161 square(elem out, const elem a)
162 {
163     uint32_t i;
164     uint32_t j;
165     uint32_t u;
166 
167     for (i = 0; i < 32; ++i) {
168         u = 0;
169         for (j = 0; j < i - j; ++j) {
170             u += a[j] * a[i - j];
171         }
172         for (j = i + 1; j < i + 32 - j; ++j) {
173             u += 38 * a[j] * a[i + 32 - j];
174         }
175         u *= 2;
176         if ((i & 1) == 0) {
177             u += a[i / 2] * a[i / 2];
178             u += 38 * a[i / 2 + 16] * a[i / 2 + 16];
179         }
180         out[i] = u;
181     }
182     squeeze(out);
183 }
184 
185 /*
186  * Constant time swap between r and s depending on b
187  */
188 static void
cswap(uint32_t p[64],uint32_t q[64],uint32_t b)189 cswap(uint32_t p[64], uint32_t q[64], uint32_t b)
190 {
191     uint32_t j;
192     uint32_t swap = 1 + ~b;
193 
194     for (j = 0; j < 64; ++j) {
195         const uint32_t t = swap & (p[j] ^ q[j]);
196         p[j] ^= t;
197         q[j] ^= t;
198     }
199 }
200 
201 /*
202  * Montgomery ladder
203  */
204 static void
monty(elem x_2_out,elem z_2_out,const elem point,const elem scalar)205 monty(elem x_2_out, elem z_2_out,
206       const elem point, const elem scalar)
207 {
208     uint32_t x_3[64] = { 0 };
209     uint32_t x_2[64] = { 0 };
210     uint32_t a0[64];
211     uint32_t a1[64];
212     uint32_t b0[64];
213     uint32_t b1[64];
214     uint32_t c1[64];
215     uint32_t r[32];
216     uint32_t s[32];
217     uint32_t t[32];
218     uint32_t u[32];
219     uint32_t swap = 0;
220     uint32_t k_t = 0;
221     int j;
222 
223     for (j = 0; j < 32; ++j) {
224         x_3[j] = point[j];
225     }
226     x_3[32] = 1;
227     x_2[0] = 1;
228 
229     for (j = 254; j >= 0; --j) {
230         k_t = (scalar[j >> 3] >> (j & 7)) & 1;
231         swap ^= k_t;
232         cswap(x_2, x_3, swap);
233         swap = k_t;
234         add(a0, x_2, x_2 + 32);
235         sub(a0 + 32, x_2, x_2 + 32);
236         add(a1, x_3, x_3 + 32);
237         sub(a1 + 32, x_3, x_3 + 32);
238         square(b0, a0);
239         square(b0 + 32, a0 + 32);
240         mult(b1, a1, a0 + 32);
241         mult(b1 + 32, a1 + 32, a0);
242         add(c1, b1, b1 + 32);
243         sub(c1 + 32, b1, b1 + 32);
244         square(r, c1 + 32);
245         sub(s, b0, b0 + 32);
246         mult121665(t, s);
247         add(u, t, b0);
248         mult(x_2, b0, b0 + 32);
249         mult(x_2 + 32, s, u);
250         square(x_3, c1);
251         mult(x_3 + 32, r, point);
252     }
253 
254     cswap(x_2, x_3, swap);
255     for (j = 0; j < 32; ++j) {
256         x_2_out[j] = x_2[j];
257     }
258     for (j = 0; j < 32; ++j) {
259         z_2_out[j] = x_2[j + 32];
260     }
261 }
262 
263 static void
recip(elem out,const elem z)264 recip(elem out, const elem z)
265 {
266     elem z2;
267     elem z9;
268     elem z11;
269     elem z2_5_0;
270     elem z2_10_0;
271     elem z2_20_0;
272     elem z2_50_0;
273     elem z2_100_0;
274     elem t0;
275     elem t1;
276     int i;
277 
278     /* 2 */ square(z2, z);
279     /* 4 */ square(t1, z2);
280     /* 8 */ square(t0, t1);
281     /* 9 */ mult(z9, t0, z);
282     /* 11 */ mult(z11, z9, z2);
283     /* 22 */ square(t0, z11);
284     /* 2^5 - 2^0 = 31 */ mult(z2_5_0, t0, z9);
285 
286     /* 2^6 - 2^1 */ square(t0, z2_5_0);
287     /* 2^7 - 2^2 */ square(t1, t0);
288     /* 2^8 - 2^3 */ square(t0, t1);
289     /* 2^9 - 2^4 */ square(t1, t0);
290     /* 2^10 - 2^5 */ square(t0, t1);
291     /* 2^10 - 2^0 */ mult(z2_10_0, t0, z2_5_0);
292 
293     /* 2^11 - 2^1 */ square(t0, z2_10_0);
294     /* 2^12 - 2^2 */ square(t1, t0);
295     /* 2^20 - 2^10 */
296     for (i = 2; i < 10; i += 2) {
297         square(t0, t1);
298         square(t1, t0);
299     }
300     /* 2^20 - 2^0 */ mult(z2_20_0, t1, z2_10_0);
301 
302     /* 2^21 - 2^1 */ square(t0, z2_20_0);
303     /* 2^22 - 2^2 */ square(t1, t0);
304     /* 2^40 - 2^20 */
305     for (i = 2; i < 20; i += 2) {
306         square(t0, t1);
307         square(t1, t0);
308     }
309     /* 2^40 - 2^0 */ mult(t0, t1, z2_20_0);
310 
311     /* 2^41 - 2^1 */ square(t1, t0);
312     /* 2^42 - 2^2 */ square(t0, t1);
313     /* 2^50 - 2^10 */
314     for (i = 2; i < 10; i += 2) {
315         square(t1, t0);
316         square(t0, t1);
317     }
318     /* 2^50 - 2^0 */ mult(z2_50_0, t0, z2_10_0);
319 
320     /* 2^51 - 2^1 */ square(t0, z2_50_0);
321     /* 2^52 - 2^2 */ square(t1, t0);
322     /* 2^100 - 2^50 */
323     for (i = 2; i < 50; i += 2) {
324         square(t0, t1);
325         square(t1, t0);
326     }
327     /* 2^100 - 2^0 */ mult(z2_100_0, t1, z2_50_0);
328 
329     /* 2^101 - 2^1 */ square(t1, z2_100_0);
330     /* 2^102 - 2^2 */ square(t0, t1);
331     /* 2^200 - 2^100 */
332     for (i = 2; i < 100; i += 2) {
333         square(t1, t0);
334         square(t0, t1);
335     }
336     /* 2^200 - 2^0 */ mult(t1, t0, z2_100_0);
337 
338     /* 2^201 - 2^1 */ square(t0, t1);
339     /* 2^202 - 2^2 */ square(t1, t0);
340     /* 2^250 - 2^50 */
341     for (i = 2; i < 50; i += 2) {
342         square(t0, t1);
343         square(t1, t0);
344     }
345     /* 2^250 - 2^0 */ mult(t0, t1, z2_50_0);
346 
347     /* 2^251 - 2^1 */ square(t1, t0);
348     /* 2^252 - 2^2 */ square(t0, t1);
349     /* 2^253 - 2^3 */ square(t1, t0);
350     /* 2^254 - 2^4 */ square(t0, t1);
351     /* 2^255 - 2^5 */ square(t1, t0);
352     /* 2^255 - 21 */ mult(out, t1, z11);
353 }
354 
355 /*
356  * Computes q = Curve25519(p, s)
357  */
358 SECStatus
ec_Curve25519_mul(PRUint8 * q,const PRUint8 * s,const PRUint8 * p)359 ec_Curve25519_mul(PRUint8 *q, const PRUint8 *s, const PRUint8 *p)
360 {
361     elem point = { 0 };
362     elem x_2 = { 0 };
363     elem z_2 = { 0 };
364     elem X = { 0 };
365     elem scalar = { 0 };
366     uint32_t i;
367 
368     /* read and mask scalar */
369     for (i = 0; i < 32; ++i) {
370         scalar[i] = s[i];
371     }
372     scalar[0] &= 0xF8;
373     scalar[31] &= 0x7F;
374     scalar[31] |= 64;
375 
376     /* read and mask point */
377     for (i = 0; i < 32; ++i) {
378         point[i] = p[i];
379     }
380     point[31] &= 0x7F;
381 
382     monty(x_2, z_2, point, scalar);
383     recip(z_2, z_2);
384     mult(X, x_2, z_2);
385     reduce(X);
386     for (i = 0; i < 32; ++i) {
387         q[i] = X[i];
388     }
389     return 0;
390 }
391