1 /* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5 /*
6 * Derived from public domain code by Matthew Dempsky and D. J. Bernstein.
7 */
8
9 #include "ecl-priv.h"
10 #include "mpi.h"
11
12 #include <stdint.h>
13 #include <stdio.h>
14
/*
 * A field element: 32 limbs in little-endian order (limb 0 is the least
 * significant). The arithmetic in this file treats each limb as one
 * radix-2^8 digit, stored in a uint32_t so that sums and products have
 * headroom before carries are propagated.
 */
typedef uint32_t elem[32];
16
17 /*
18 * Add two field elements.
19 * out = a + b
20 */
21 static void
add(elem out,const elem a,const elem b)22 add(elem out, const elem a, const elem b)
23 {
24 uint32_t j;
25 uint32_t u = 0;
26 for (j = 0; j < 31; ++j) {
27 u += a[j] + b[j];
28 out[j] = u & 0xFF;
29 u >>= 8;
30 }
31 u += a[31] + b[31];
32 out[31] = u;
33 }
34
35 /*
36 * Subtract two field elements.
37 * out = a - b
38 */
39 static void
sub(elem out,const elem a,const elem b)40 sub(elem out, const elem a, const elem b)
41 {
42 uint32_t j;
43 uint32_t u;
44 u = 218;
45 for (j = 0; j < 31; ++j) {
46 u += a[j] + 0xFF00 - b[j];
47 out[j] = u & 0xFF;
48 u >>= 8;
49 }
50 u += a[31] - b[31];
51 out[31] = u;
52 }
53
54 /*
55 * "Squeeze" an element after multiplication (and square).
56 */
57 static void
squeeze(elem a)58 squeeze(elem a)
59 {
60 uint32_t j;
61 uint32_t u;
62 u = 0;
63 for (j = 0; j < 31; ++j) {
64 u += a[j];
65 a[j] = u & 0xFF;
66 u >>= 8;
67 }
68 u += a[31];
69 a[31] = u & 0x7F;
70 u = 19 * (u >> 7);
71 for (j = 0; j < 31; ++j) {
72 u += a[j];
73 a[j] = u & 0xFF;
74 u >>= 8;
75 }
76 a[31] += u;
77 }
78
79 static const elem minusp = { 19, 0, 0, 0, 0, 0, 0, 0,
80 0, 0, 0, 0, 0, 0, 0, 0,
81 0, 0, 0, 0, 0, 0, 0, 0,
82 0, 0, 0, 0, 0, 0, 0, 128 };
83
84 /*
85 * Reduce point a by 2^255-19
86 */
87 static void
reduce(elem a)88 reduce(elem a)
89 {
90 elem aorig;
91 uint32_t j;
92 uint32_t negative;
93
94 for (j = 0; j < 32; ++j) {
95 aorig[j] = a[j];
96 }
97 add(a, a, minusp);
98 negative = 1 + ~((a[31] >> 7) & 1);
99 for (j = 0; j < 32; ++j) {
100 a[j] ^= negative & (aorig[j] ^ a[j]);
101 }
102 }
103
104 /*
105 * Multiplication and squeeze
106 * out = a * b
107 */
108 static void
mult(elem out,const elem a,const elem b)109 mult(elem out, const elem a, const elem b)
110 {
111 uint32_t i;
112 uint32_t j;
113 uint32_t u;
114
115 for (i = 0; i < 32; ++i) {
116 u = 0;
117 for (j = 0; j <= i; ++j) {
118 u += a[j] * b[i - j];
119 }
120 for (j = i + 1; j < 32; ++j) {
121 u += 38 * a[j] * b[i + 32 - j];
122 }
123 out[i] = u;
124 }
125 squeeze(out);
126 }
127
128 /*
129 * Multiplication
130 * out = 121665 * a
131 */
132 static void
mult121665(elem out,const elem a)133 mult121665(elem out, const elem a)
134 {
135 uint32_t j;
136 uint32_t u;
137
138 u = 0;
139 for (j = 0; j < 31; ++j) {
140 u += 121665 * a[j];
141 out[j] = u & 0xFF;
142 u >>= 8;
143 }
144 u += 121665 * a[31];
145 out[31] = u & 0x7F;
146 u = 19 * (u >> 7);
147 for (j = 0; j < 31; ++j) {
148 u += out[j];
149 out[j] = u & 0xFF;
150 u >>= 8;
151 }
152 u += out[j];
153 out[j] = u;
154 }
155
156 /*
157 * Square a and squeeze the result.
158 * out = a * a
159 */
160 static void
square(elem out,const elem a)161 square(elem out, const elem a)
162 {
163 uint32_t i;
164 uint32_t j;
165 uint32_t u;
166
167 for (i = 0; i < 32; ++i) {
168 u = 0;
169 for (j = 0; j < i - j; ++j) {
170 u += a[j] * a[i - j];
171 }
172 for (j = i + 1; j < i + 32 - j; ++j) {
173 u += 38 * a[j] * a[i + 32 - j];
174 }
175 u *= 2;
176 if ((i & 1) == 0) {
177 u += a[i / 2] * a[i / 2];
178 u += 38 * a[i / 2 + 16] * a[i / 2 + 16];
179 }
180 out[i] = u;
181 }
182 squeeze(out);
183 }
184
/*
 * Swap the 64-word arrays p and q depending on b, without branching on b.
 *
 * b is expected to be 0 or 1. It is turned into a mask by negation
 * (1 + ~b): 0 gives an all-zero mask and leaves both arrays untouched,
 * 1 gives an all-ones mask and exchanges every word. The same sequence of
 * operations runs in both cases.
 */
static void
cswap(uint32_t p[64], uint32_t q[64], uint32_t b)
{
    uint32_t j;
    uint32_t swap = 1 + ~b;

    for (j = 0; j < 64; ++j) {
        /* t is p[j] ^ q[j] when swapping, 0 otherwise. */
        const uint32_t t = swap & (p[j] ^ q[j]);
        p[j] ^= t;
        q[j] ^= t;
    }
}
200
201 /*
202 * Montgomery ladder
203 */
204 static void
monty(elem x_2_out,elem z_2_out,const elem point,const elem scalar)205 monty(elem x_2_out, elem z_2_out,
206 const elem point, const elem scalar)
207 {
208 uint32_t x_3[64] = { 0 };
209 uint32_t x_2[64] = { 0 };
210 uint32_t a0[64];
211 uint32_t a1[64];
212 uint32_t b0[64];
213 uint32_t b1[64];
214 uint32_t c1[64];
215 uint32_t r[32];
216 uint32_t s[32];
217 uint32_t t[32];
218 uint32_t u[32];
219 uint32_t swap = 0;
220 uint32_t k_t = 0;
221 int j;
222
223 for (j = 0; j < 32; ++j) {
224 x_3[j] = point[j];
225 }
226 x_3[32] = 1;
227 x_2[0] = 1;
228
229 for (j = 254; j >= 0; --j) {
230 k_t = (scalar[j >> 3] >> (j & 7)) & 1;
231 swap ^= k_t;
232 cswap(x_2, x_3, swap);
233 swap = k_t;
234 add(a0, x_2, x_2 + 32);
235 sub(a0 + 32, x_2, x_2 + 32);
236 add(a1, x_3, x_3 + 32);
237 sub(a1 + 32, x_3, x_3 + 32);
238 square(b0, a0);
239 square(b0 + 32, a0 + 32);
240 mult(b1, a1, a0 + 32);
241 mult(b1 + 32, a1 + 32, a0);
242 add(c1, b1, b1 + 32);
243 sub(c1 + 32, b1, b1 + 32);
244 square(r, c1 + 32);
245 sub(s, b0, b0 + 32);
246 mult121665(t, s);
247 add(u, t, b0);
248 mult(x_2, b0, b0 + 32);
249 mult(x_2 + 32, s, u);
250 square(x_3, c1);
251 mult(x_3 + 32, r, point);
252 }
253
254 cswap(x_2, x_3, swap);
255 for (j = 0; j < 32; ++j) {
256 x_2_out[j] = x_2[j];
257 }
258 for (j = 0; j < 32; ++j) {
259 z_2_out[j] = x_2[j + 32];
260 }
261 }
262
263 static void
recip(elem out,const elem z)264 recip(elem out, const elem z)
265 {
266 elem z2;
267 elem z9;
268 elem z11;
269 elem z2_5_0;
270 elem z2_10_0;
271 elem z2_20_0;
272 elem z2_50_0;
273 elem z2_100_0;
274 elem t0;
275 elem t1;
276 int i;
277
278 /* 2 */ square(z2, z);
279 /* 4 */ square(t1, z2);
280 /* 8 */ square(t0, t1);
281 /* 9 */ mult(z9, t0, z);
282 /* 11 */ mult(z11, z9, z2);
283 /* 22 */ square(t0, z11);
284 /* 2^5 - 2^0 = 31 */ mult(z2_5_0, t0, z9);
285
286 /* 2^6 - 2^1 */ square(t0, z2_5_0);
287 /* 2^7 - 2^2 */ square(t1, t0);
288 /* 2^8 - 2^3 */ square(t0, t1);
289 /* 2^9 - 2^4 */ square(t1, t0);
290 /* 2^10 - 2^5 */ square(t0, t1);
291 /* 2^10 - 2^0 */ mult(z2_10_0, t0, z2_5_0);
292
293 /* 2^11 - 2^1 */ square(t0, z2_10_0);
294 /* 2^12 - 2^2 */ square(t1, t0);
295 /* 2^20 - 2^10 */
296 for (i = 2; i < 10; i += 2) {
297 square(t0, t1);
298 square(t1, t0);
299 }
300 /* 2^20 - 2^0 */ mult(z2_20_0, t1, z2_10_0);
301
302 /* 2^21 - 2^1 */ square(t0, z2_20_0);
303 /* 2^22 - 2^2 */ square(t1, t0);
304 /* 2^40 - 2^20 */
305 for (i = 2; i < 20; i += 2) {
306 square(t0, t1);
307 square(t1, t0);
308 }
309 /* 2^40 - 2^0 */ mult(t0, t1, z2_20_0);
310
311 /* 2^41 - 2^1 */ square(t1, t0);
312 /* 2^42 - 2^2 */ square(t0, t1);
313 /* 2^50 - 2^10 */
314 for (i = 2; i < 10; i += 2) {
315 square(t1, t0);
316 square(t0, t1);
317 }
318 /* 2^50 - 2^0 */ mult(z2_50_0, t0, z2_10_0);
319
320 /* 2^51 - 2^1 */ square(t0, z2_50_0);
321 /* 2^52 - 2^2 */ square(t1, t0);
322 /* 2^100 - 2^50 */
323 for (i = 2; i < 50; i += 2) {
324 square(t0, t1);
325 square(t1, t0);
326 }
327 /* 2^100 - 2^0 */ mult(z2_100_0, t1, z2_50_0);
328
329 /* 2^101 - 2^1 */ square(t1, z2_100_0);
330 /* 2^102 - 2^2 */ square(t0, t1);
331 /* 2^200 - 2^100 */
332 for (i = 2; i < 100; i += 2) {
333 square(t1, t0);
334 square(t0, t1);
335 }
336 /* 2^200 - 2^0 */ mult(t1, t0, z2_100_0);
337
338 /* 2^201 - 2^1 */ square(t0, t1);
339 /* 2^202 - 2^2 */ square(t1, t0);
340 /* 2^250 - 2^50 */
341 for (i = 2; i < 50; i += 2) {
342 square(t0, t1);
343 square(t1, t0);
344 }
345 /* 2^250 - 2^0 */ mult(t0, t1, z2_50_0);
346
347 /* 2^251 - 2^1 */ square(t1, t0);
348 /* 2^252 - 2^2 */ square(t0, t1);
349 /* 2^253 - 2^3 */ square(t1, t0);
350 /* 2^254 - 2^4 */ square(t0, t1);
351 /* 2^255 - 2^5 */ square(t1, t0);
352 /* 2^255 - 21 */ mult(out, t1, z11);
353 }
354
355 /*
356 * Computes q = Curve25519(p, s)
357 */
358 SECStatus
ec_Curve25519_mul(PRUint8 * q,const PRUint8 * s,const PRUint8 * p)359 ec_Curve25519_mul(PRUint8 *q, const PRUint8 *s, const PRUint8 *p)
360 {
361 elem point = { 0 };
362 elem x_2 = { 0 };
363 elem z_2 = { 0 };
364 elem X = { 0 };
365 elem scalar = { 0 };
366 uint32_t i;
367
368 /* read and mask scalar */
369 for (i = 0; i < 32; ++i) {
370 scalar[i] = s[i];
371 }
372 scalar[0] &= 0xF8;
373 scalar[31] &= 0x7F;
374 scalar[31] |= 64;
375
376 /* read and mask point */
377 for (i = 0; i < 32; ++i) {
378 point[i] = p[i];
379 }
380 point[31] &= 0x7F;
381
382 monty(x_2, z_2, point, scalar);
383 recip(z_2, z_2);
384 mult(X, x_2, z_2);
385 reduce(X);
386 for (i = 0; i < 32; ++i) {
387 q[i] = X[i];
388 }
389 return 0;
390 }
391