1 /*
2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 package sun.security.ec;
27 
28 import sun.security.ec.point.*;
29 import sun.security.util.math.*;
30 import sun.security.util.math.intpoly.*;
31 
32 import java.math.BigInteger;
33 import java.security.ProviderException;
34 import java.security.spec.ECFieldFp;
35 import java.security.spec.ECParameterSpec;
36 import java.security.spec.EllipticCurve;
37 import java.util.Map;
38 import java.util.Optional;
39 
40 /*
41  * Elliptic curve point arithmetic for prime-order curves where a=-3.
42  * Formulas are derived from "Complete addition formulas for prime order
43  * elliptic curves" by Renes, Costello, and Batina.
44  */
45 
46 public class ECOperations {
47 
48     /*
49      * An exception indicating a problem with an intermediate value produced
50      * by some part of the computation. For example, the signing operation
51      * will throw this exception to indicate that the r or s value is 0, and
52      * that the signing operation should be tried again with a different nonce.
53      */
54     static class IntermediateValueException extends Exception {
55         private static final long serialVersionUID = 1;
56     }
57 
58     static final Map<BigInteger, IntegerFieldModuloP> fields = Map.of(
59         IntegerPolynomialP256.MODULUS, new IntegerPolynomialP256(),
60         IntegerPolynomialP384.MODULUS, new IntegerPolynomialP384(),
61         IntegerPolynomialP521.MODULUS, new IntegerPolynomialP521()
62     );
63 
64     static final Map<BigInteger, IntegerFieldModuloP> orderFields = Map.of(
65         P256OrderField.MODULUS, new P256OrderField(),
66         P384OrderField.MODULUS, new P384OrderField(),
67         P521OrderField.MODULUS, new P521OrderField()
68     );
69 
forParameters(ECParameterSpec params)70     public static Optional<ECOperations> forParameters(ECParameterSpec params) {
71 
72         EllipticCurve curve = params.getCurve();
73         if (!(curve.getField() instanceof ECFieldFp)) {
74             return Optional.empty();
75         }
76         ECFieldFp primeField = (ECFieldFp) curve.getField();
77 
78         BigInteger three = BigInteger.valueOf(3);
79         if (!primeField.getP().subtract(curve.getA()).equals(three)) {
80             return Optional.empty();
81         }
82         IntegerFieldModuloP field = fields.get(primeField.getP());
83         if (field == null) {
84             return Optional.empty();
85         }
86 
87         IntegerFieldModuloP orderField = orderFields.get(params.getOrder());
88         if (orderField == null) {
89             return Optional.empty();
90         }
91 
92         ImmutableIntegerModuloP b = field.getElement(curve.getB());
93         ECOperations ecOps = new ECOperations(b, orderField);
94         return Optional.of(ecOps);
95     }
96 
97     final ImmutableIntegerModuloP b;
98     final SmallValue one;
99     final SmallValue two;
100     final SmallValue three;
101     final SmallValue four;
102     final ProjectivePoint.Immutable neutral;
103     private final IntegerFieldModuloP orderField;
104 
ECOperations(IntegerModuloP b, IntegerFieldModuloP orderField)105     public ECOperations(IntegerModuloP b, IntegerFieldModuloP orderField) {
106         this.b = b.fixed();
107         this.orderField = orderField;
108 
109         this.one = b.getField().getSmallValue(1);
110         this.two = b.getField().getSmallValue(2);
111         this.three = b.getField().getSmallValue(3);
112         this.four = b.getField().getSmallValue(4);
113 
114         IntegerFieldModuloP field = b.getField();
115         this.neutral = new ProjectivePoint.Immutable(field.get0(),
116             field.get1(), field.get0());
117     }
118 
getField()119     public IntegerFieldModuloP getField() {
120         return b.getField();
121     }
getOrderField()122     public IntegerFieldModuloP getOrderField() {
123         return orderField;
124     }
125 
getNeutral()126     protected ProjectivePoint.Immutable getNeutral() {
127         return neutral;
128     }
129 
isNeutral(Point p)130     public boolean isNeutral(Point p) {
131         ProjectivePoint<?> pp = (ProjectivePoint<?>) p;
132 
133         IntegerModuloP z = pp.getZ();
134 
135         IntegerFieldModuloP field = z.getField();
136         int byteLength = (field.getSize().bitLength() + 7) / 8;
137         byte[] zBytes = z.asByteArray(byteLength);
138         return allZero(zBytes);
139     }
140 
seedToScalar(byte[] seedBytes)141     byte[] seedToScalar(byte[] seedBytes)
142         throws IntermediateValueException {
143 
144         // Produce a nonce from the seed using FIPS 186-4,section B.5.1:
145         // Per-Message Secret Number Generation Using Extra Random Bits
146         // or
147         // Produce a scalar from the seed using FIPS 186-4, section B.4.1:
148         // Key Pair Generation Using Extra Random Bits
149 
150         // To keep the implementation simple, sample in the range [0,n)
151         // and throw IntermediateValueException in the (unlikely) event
152         // that the result is 0.
153 
154         // Get 64 extra bits and reduce in to the nonce
155         int seedBits = orderField.getSize().bitLength() + 64;
156         if (seedBytes.length * 8 < seedBits) {
157             throw new ProviderException("Incorrect seed length: " +
158             seedBytes.length * 8 + " < " + seedBits);
159         }
160 
161         // input conversion only works on byte boundaries
162         // clear high-order bits of last byte so they don't influence nonce
163         int lastByteBits = seedBits % 8;
164         if (lastByteBits != 0) {
165             int lastByteIndex = seedBits / 8;
166             byte mask = (byte) (0xFF >>> (8 - lastByteBits));
167             seedBytes[lastByteIndex] &= mask;
168         }
169 
170         int seedLength = (seedBits + 7) / 8;
171         IntegerModuloP scalarElem =
172             orderField.getElement(seedBytes, 0, seedLength, (byte) 0);
173         int scalarLength = (orderField.getSize().bitLength() + 7) / 8;
174         byte[] scalarArr = new byte[scalarLength];
175         scalarElem.asByteArray(scalarArr);
176         if (ECOperations.allZero(scalarArr)) {
177             throw new IntermediateValueException();
178         }
179         return scalarArr;
180     }
181 
182     /*
183      * Compare all values in the array to 0 without branching on any value
184      *
185      */
allZero(byte[] arr)186     public static boolean allZero(byte[] arr) {
187         byte acc = 0;
188         for (int i = 0; i < arr.length; i++) {
189             acc |= arr[i];
190         }
191         return acc == 0;
192     }
193 
194     /*
195      * 4-bit branchless array lookup for projective points.
196      */
lookup4(ProjectivePoint.Immutable[] arr, int index, ProjectivePoint.Mutable result, IntegerModuloP zero)197     private void lookup4(ProjectivePoint.Immutable[] arr, int index,
198         ProjectivePoint.Mutable result, IntegerModuloP zero) {
199 
200         for (int i = 0; i < 16; i++) {
201             int xor = index ^ i;
202             int bit3 = (xor & 0x8) >>> 3;
203             int bit2 = (xor & 0x4) >>> 2;
204             int bit1 = (xor & 0x2) >>> 1;
205             int bit0 = (xor & 0x1);
206             int inverse = bit0 | bit1 | bit2 | bit3;
207             int set = 1 - inverse;
208 
209             ProjectivePoint.Immutable pi = arr[i];
210             result.conditionalSet(pi, set);
211         }
212     }
213 
double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)214     private void double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0,
215         MutableIntegerModuloP t1, MutableIntegerModuloP t2,
216         MutableIntegerModuloP t3, MutableIntegerModuloP t4) {
217 
218         for (int i = 0; i < 4; i++) {
219             setDouble(p, t0, t1, t2, t3, t4);
220         }
221     }
222 
223     /**
224      * Multiply an affine point by a scalar and return the result as a mutable
225      * point.
226      *
227      * @param affineP the point
228      * @param s the scalar as a little-endian array
229      * @return the product
230      */
multiply(AffinePoint affineP, byte[] s)231     public MutablePoint multiply(AffinePoint affineP, byte[] s) {
232 
233         // 4-bit windowed multiply with branchless lookup.
234         // The mixed addition is faster, so it is used to construct the array
235         // at the beginning of the operation.
236 
237         IntegerFieldModuloP field = affineP.getX().getField();
238         ImmutableIntegerModuloP zero = field.get0();
239         // temporaries
240         MutableIntegerModuloP t0 = zero.mutable();
241         MutableIntegerModuloP t1 = zero.mutable();
242         MutableIntegerModuloP t2 = zero.mutable();
243         MutableIntegerModuloP t3 = zero.mutable();
244         MutableIntegerModuloP t4 = zero.mutable();
245 
246         ProjectivePoint.Mutable result = new ProjectivePoint.Mutable(field);
247         result.getY().setValue(field.get1().mutable());
248 
249         ProjectivePoint.Immutable[] pointMultiples =
250             new ProjectivePoint.Immutable[16];
251         // 0P is neutral---same as initial result value
252         pointMultiples[0] = result.fixed();
253 
254         ProjectivePoint.Mutable ps = new ProjectivePoint.Mutable(field);
255         ps.setValue(affineP);
256         // 1P = P
257         pointMultiples[1] = ps.fixed();
258 
259         // the rest are calculated using mixed point addition
260         for (int i = 2; i < 16; i++) {
261             setSum(ps, affineP, t0, t1, t2, t3, t4);
262             pointMultiples[i] = ps.fixed();
263         }
264 
265         ProjectivePoint.Mutable lookupResult = ps.mutable();
266 
267         for (int i = s.length - 1; i >= 0; i--) {
268 
269             double4(result, t0, t1, t2, t3, t4);
270 
271             int high = (0xFF & s[i]) >>> 4;
272             lookup4(pointMultiples, high, lookupResult, zero);
273             setSum(result, lookupResult, t0, t1, t2, t3, t4);
274 
275             double4(result, t0, t1, t2, t3, t4);
276 
277             int low = 0xF & s[i];
278             lookup4(pointMultiples, low, lookupResult, zero);
279             setSum(result, lookupResult, t0, t1, t2, t3, t4);
280         }
281 
282         return result;
283 
284     }
285 
286     /*
287      * Point double
288      */
setDouble(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)289     private void setDouble(ProjectivePoint.Mutable p, MutableIntegerModuloP t0,
290         MutableIntegerModuloP t1, MutableIntegerModuloP t2,
291         MutableIntegerModuloP t3, MutableIntegerModuloP t4) {
292 
293         t0.setValue(p.getX()).setSquare();
294         t1.setValue(p.getY()).setSquare();
295         t2.setValue(p.getZ()).setSquare();
296         t3.setValue(p.getX()).setProduct(p.getY());
297         t4.setValue(p.getY()).setProduct(p.getZ());
298 
299         t3.setSum(t3);
300         p.getZ().setProduct(p.getX());
301 
302         p.getZ().setProduct(two);
303 
304         p.getY().setValue(t2).setProduct(b);
305         p.getY().setDifference(p.getZ());
306 
307         p.getX().setValue(p.getY()).setProduct(two);
308         p.getY().setSum(p.getX());
309         p.getY().setReduced();
310         p.getX().setValue(t1).setDifference(p.getY());
311 
312         p.getY().setSum(t1);
313         p.getY().setProduct(p.getX());
314         p.getX().setProduct(t3);
315 
316         t3.setValue(t2).setProduct(two);
317         t2.setSum(t3);
318         p.getZ().setProduct(b);
319 
320         t2.setReduced();
321         p.getZ().setDifference(t2);
322         p.getZ().setDifference(t0);
323         t3.setValue(p.getZ()).setProduct(two);
324         p.getZ().setReduced();
325         p.getZ().setSum(t3);
326         t0.setProduct(three);
327 
328         t0.setDifference(t2);
329         t0.setProduct(p.getZ());
330         p.getY().setSum(t0);
331 
332         t4.setSum(t4);
333         p.getZ().setProduct(t4);
334 
335         p.getX().setDifference(p.getZ());
336         p.getZ().setValue(t4).setProduct(t1);
337 
338         p.getZ().setProduct(four);
339 
340     }
341 
342     /*
343      * Mixed point addition. This method constructs new temporaries each time
344      * it is called. For better efficiency, the method that reuses temporaries
345      * should be used if more than one sum will be computed.
346      */
setSum(MutablePoint p, AffinePoint p2)347     public void setSum(MutablePoint p, AffinePoint p2) {
348 
349         IntegerModuloP zero = p.getField().get0();
350         MutableIntegerModuloP t0 = zero.mutable();
351         MutableIntegerModuloP t1 = zero.mutable();
352         MutableIntegerModuloP t2 = zero.mutable();
353         MutableIntegerModuloP t3 = zero.mutable();
354         MutableIntegerModuloP t4 = zero.mutable();
355         setSum((ProjectivePoint.Mutable) p, p2, t0, t1, t2, t3, t4);
356 
357     }
358 
359     /*
360      * Mixed point addition
361      */
setSum(ProjectivePoint.Mutable p, AffinePoint p2, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)362     private void setSum(ProjectivePoint.Mutable p, AffinePoint p2,
363         MutableIntegerModuloP t0, MutableIntegerModuloP t1,
364         MutableIntegerModuloP t2, MutableIntegerModuloP t3,
365         MutableIntegerModuloP t4) {
366 
367         t0.setValue(p.getX()).setProduct(p2.getX());
368         t1.setValue(p.getY()).setProduct(p2.getY());
369         t3.setValue(p2.getX()).setSum(p2.getY());
370         t4.setValue(p.getX()).setSum(p.getY());
371         p.getX().setReduced();
372         t3.setProduct(t4);
373         t4.setValue(t0).setSum(t1);
374 
375         t3.setDifference(t4);
376         t4.setValue(p2.getY()).setProduct(p.getZ());
377         t4.setSum(p.getY());
378 
379         p.getY().setValue(p2.getX()).setProduct(p.getZ());
380         p.getY().setSum(p.getX());
381         t2.setValue(p.getZ());
382         p.getZ().setProduct(b);
383 
384         p.getX().setValue(p.getY()).setDifference(p.getZ());
385         p.getX().setReduced();
386         p.getZ().setValue(p.getX()).setProduct(two);
387         p.getX().setSum(p.getZ());
388 
389         p.getZ().setValue(t1).setDifference(p.getX());
390         p.getX().setSum(t1);
391         p.getY().setProduct(b);
392 
393         t1.setValue(t2).setProduct(two);
394         t2.setSum(t1);
395         t2.setReduced();
396         p.getY().setDifference(t2);
397 
398         p.getY().setDifference(t0);
399         p.getY().setReduced();
400         t1.setValue(p.getY()).setProduct(two);
401         p.getY().setSum(t1);
402 
403         t1.setValue(t0).setProduct(two);
404         t0.setSum(t1);
405         t0.setDifference(t2);
406 
407         t1.setValue(t4).setProduct(p.getY());
408         t2.setValue(t0).setProduct(p.getY());
409         p.getY().setValue(p.getX()).setProduct(p.getZ());
410 
411         p.getY().setSum(t2);
412         p.getX().setProduct(t3);
413         p.getX().setDifference(t1);
414 
415         p.getZ().setProduct(t4);
416         t1.setValue(t3).setProduct(t0);
417         p.getZ().setSum(t1);
418 
419     }
420 
421     /*
422      * Projective point addition
423      */
setSum(ProjectivePoint.Mutable p, ProjectivePoint.Mutable p2, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)424     private void setSum(ProjectivePoint.Mutable p, ProjectivePoint.Mutable p2,
425         MutableIntegerModuloP t0, MutableIntegerModuloP t1,
426         MutableIntegerModuloP t2, MutableIntegerModuloP t3,
427         MutableIntegerModuloP t4) {
428 
429         t0.setValue(p.getX()).setProduct(p2.getX());
430         t1.setValue(p.getY()).setProduct(p2.getY());
431         t2.setValue(p.getZ()).setProduct(p2.getZ());
432 
433         t3.setValue(p.getX()).setSum(p.getY());
434         t4.setValue(p2.getX()).setSum(p2.getY());
435         t3.setProduct(t4);
436 
437         t4.setValue(t0).setSum(t1);
438         t3.setDifference(t4);
439         t4.setValue(p.getY()).setSum(p.getZ());
440 
441         p.getY().setValue(p2.getY()).setSum(p2.getZ());
442         t4.setProduct(p.getY());
443         p.getY().setValue(t1).setSum(t2);
444 
445         t4.setDifference(p.getY());
446         p.getX().setSum(p.getZ());
447         p.getY().setValue(p2.getX()).setSum(p2.getZ());
448 
449         p.getX().setProduct(p.getY());
450         p.getY().setValue(t0).setSum(t2);
451         p.getY().setAdditiveInverse().setSum(p.getX());
452         p.getY().setReduced();
453 
454         p.getZ().setValue(t2).setProduct(b);
455         p.getX().setValue(p.getY()).setDifference(p.getZ());
456         p.getZ().setValue(p.getX()).setProduct(two);
457 
458         p.getX().setSum(p.getZ());
459         p.getX().setReduced();
460         p.getZ().setValue(t1).setDifference(p.getX());
461         p.getX().setSum(t1);
462 
463         p.getY().setProduct(b);
464         t1.setValue(t2).setSum(t2);
465         t2.setSum(t1);
466         t2.setReduced();
467 
468         p.getY().setDifference(t2);
469         p.getY().setDifference(t0);
470         p.getY().setReduced();
471         t1.setValue(p.getY()).setSum(p.getY());
472 
473         p.getY().setSum(t1);
474         t1.setValue(t0).setProduct(two);
475         t0.setSum(t1);
476 
477         t0.setDifference(t2);
478         t1.setValue(t4).setProduct(p.getY());
479         t2.setValue(t0).setProduct(p.getY());
480 
481         p.getY().setValue(p.getX()).setProduct(p.getZ());
482         p.getY().setSum(t2);
483         p.getX().setProduct(t3);
484 
485         p.getX().setDifference(t1);
486         p.getZ().setProduct(t4);
487         t1.setValue(t3).setProduct(t0);
488 
489         p.getZ().setSum(t1);
490 
491     }
492 }
493 
494