/*
 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
25 
26 package sun.security.ec;
27 
28 import sun.security.ec.point.*;
29 import sun.security.util.math.*;
30 import sun.security.util.math.intpoly.*;
31 
32 import java.math.BigInteger;
33 import java.security.ProviderException;
34 import java.security.spec.ECFieldFp;
35 import java.security.spec.ECParameterSpec;
36 import java.security.spec.EllipticCurve;
37 import java.util.Collections;
38 import java.util.HashMap;
39 import java.util.Map;
40 import java.util.Optional;
41 
42 /*
43  * Elliptic curve point arithmetic for prime-order curves where a=-3.
44  * Formulas are derived from "Complete addition formulas for prime order
45  * elliptic curves" by Renes, Costello, and Batina.
46  */
47 
48 public class ECOperations {
49 
50     /*
51      * An exception indicating a problem with an intermediate value produced
52      * by some part of the computation. For example, the signing operation
53      * will throw this exception to indicate that the r or s value is 0, and
54      * that the signing operation should be tried again with a different nonce.
55      */
56     static class IntermediateValueException extends Exception {
57         private static final long serialVersionUID = 1;
58     }
59 
60     static final Map<BigInteger, IntegerFieldModuloP> fields;
61 
62     static final Map<BigInteger, IntegerFieldModuloP> orderFields;
63 
64     static {
65         Map<BigInteger, IntegerFieldModuloP> map = new HashMap<>();
map.put(IntegerPolynomialP256.MODULUS, new IntegerPolynomialP256())66         map.put(IntegerPolynomialP256.MODULUS, new IntegerPolynomialP256());
map.put(IntegerPolynomialP384.MODULUS, new IntegerPolynomialP384())67         map.put(IntegerPolynomialP384.MODULUS, new IntegerPolynomialP384());
map.put(IntegerPolynomialP521.MODULUS, new IntegerPolynomialP521())68         map.put(IntegerPolynomialP521.MODULUS, new IntegerPolynomialP521());
69         fields = Collections.unmodifiableMap(map);
70         map = new HashMap<>();
map.put(P256OrderField.MODULUS, new P256OrderField())71         map.put(P256OrderField.MODULUS, new P256OrderField());
map.put(P384OrderField.MODULUS, new P384OrderField())72         map.put(P384OrderField.MODULUS, new P384OrderField());
map.put(P521OrderField.MODULUS, new P521OrderField())73         map.put(P521OrderField.MODULUS, new P521OrderField());
74         orderFields = Collections.unmodifiableMap(map);
75     }
76 
forParameters(ECParameterSpec params)77     public static Optional<ECOperations> forParameters(ECParameterSpec params) {
78 
79         EllipticCurve curve = params.getCurve();
80         if (!(curve.getField() instanceof ECFieldFp)) {
81             return Optional.empty();
82         }
83         ECFieldFp primeField = (ECFieldFp) curve.getField();
84 
85         BigInteger three = BigInteger.valueOf(3);
86         if (!primeField.getP().subtract(curve.getA()).equals(three)) {
87             return Optional.empty();
88         }
89         IntegerFieldModuloP field = fields.get(primeField.getP());
90         if (field == null) {
91             return Optional.empty();
92         }
93 
94         IntegerFieldModuloP orderField = orderFields.get(params.getOrder());
95         if (orderField == null) {
96             return Optional.empty();
97         }
98 
99         ImmutableIntegerModuloP b = field.getElement(curve.getB());
100         ECOperations ecOps = new ECOperations(b, orderField);
101         return Optional.of(ecOps);
102     }
103 
104     final ImmutableIntegerModuloP b;
105     final SmallValue one;
106     final SmallValue two;
107     final SmallValue three;
108     final SmallValue four;
109     final ProjectivePoint.Immutable neutral;
110     private final IntegerFieldModuloP orderField;
111 
ECOperations(IntegerModuloP b, IntegerFieldModuloP orderField)112     public ECOperations(IntegerModuloP b, IntegerFieldModuloP orderField) {
113         this.b = b.fixed();
114         this.orderField = orderField;
115 
116         this.one = b.getField().getSmallValue(1);
117         this.two = b.getField().getSmallValue(2);
118         this.three = b.getField().getSmallValue(3);
119         this.four = b.getField().getSmallValue(4);
120 
121         IntegerFieldModuloP field = b.getField();
122         this.neutral = new ProjectivePoint.Immutable(field.get0(),
123             field.get1(), field.get0());
124     }
125 
getField()126     public IntegerFieldModuloP getField() {
127         return b.getField();
128     }
getOrderField()129     public IntegerFieldModuloP getOrderField() {
130         return orderField;
131     }
132 
getNeutral()133     protected ProjectivePoint.Immutable getNeutral() {
134         return neutral;
135     }
136 
isNeutral(Point p)137     public boolean isNeutral(Point p) {
138         ProjectivePoint<?> pp = (ProjectivePoint<?>) p;
139 
140         IntegerModuloP z = pp.getZ();
141 
142         IntegerFieldModuloP field = z.getField();
143         int byteLength = (field.getSize().bitLength() + 7) / 8;
144         byte[] zBytes = z.asByteArray(byteLength);
145         return allZero(zBytes);
146     }
147 
seedToScalar(byte[] seedBytes)148     byte[] seedToScalar(byte[] seedBytes)
149         throws IntermediateValueException {
150 
151         // Produce a nonce from the seed using FIPS 186-4,section B.5.1:
152         // Per-Message Secret Number Generation Using Extra Random Bits
153         // or
154         // Produce a scalar from the seed using FIPS 186-4, section B.4.1:
155         // Key Pair Generation Using Extra Random Bits
156 
157         // To keep the implementation simple, sample in the range [0,n)
158         // and throw IntermediateValueException in the (unlikely) event
159         // that the result is 0.
160 
161         // Get 64 extra bits and reduce in to the nonce
162         int seedBits = orderField.getSize().bitLength() + 64;
163         if (seedBytes.length * 8 < seedBits) {
164             throw new ProviderException("Incorrect seed length: " +
165             seedBytes.length * 8 + " < " + seedBits);
166         }
167 
168         // input conversion only works on byte boundaries
169         // clear high-order bits of last byte so they don't influence nonce
170         int lastByteBits = seedBits % 8;
171         if (lastByteBits != 0) {
172             int lastByteIndex = seedBits / 8;
173             byte mask = (byte) (0xFF >>> (8 - lastByteBits));
174             seedBytes[lastByteIndex] &= mask;
175         }
176 
177         int seedLength = (seedBits + 7) / 8;
178         IntegerModuloP scalarElem =
179             orderField.getElement(seedBytes, 0, seedLength, (byte) 0);
180         int scalarLength = (orderField.getSize().bitLength() + 7) / 8;
181         byte[] scalarArr = new byte[scalarLength];
182         scalarElem.asByteArray(scalarArr);
183         if (ECOperations.allZero(scalarArr)) {
184             throw new IntermediateValueException();
185         }
186         return scalarArr;
187     }
188 
189     /*
190      * Compare all values in the array to 0 without branching on any value
191      *
192      */
allZero(byte[] arr)193     public static boolean allZero(byte[] arr) {
194         byte acc = 0;
195         for (int i = 0; i < arr.length; i++) {
196             acc |= arr[i];
197         }
198         return acc == 0;
199     }
200 
201     /*
202      * 4-bit branchless array lookup for projective points.
203      */
lookup4(ProjectivePoint.Immutable[] arr, int index, ProjectivePoint.Mutable result, IntegerModuloP zero)204     private void lookup4(ProjectivePoint.Immutable[] arr, int index,
205         ProjectivePoint.Mutable result, IntegerModuloP zero) {
206 
207         for (int i = 0; i < 16; i++) {
208             int xor = index ^ i;
209             int bit3 = (xor & 0x8) >>> 3;
210             int bit2 = (xor & 0x4) >>> 2;
211             int bit1 = (xor & 0x2) >>> 1;
212             int bit0 = (xor & 0x1);
213             int inverse = bit0 | bit1 | bit2 | bit3;
214             int set = 1 - inverse;
215 
216             ProjectivePoint.Immutable pi = arr[i];
217             result.conditionalSet(pi, set);
218         }
219     }
220 
double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)221     private void double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0,
222         MutableIntegerModuloP t1, MutableIntegerModuloP t2,
223         MutableIntegerModuloP t3, MutableIntegerModuloP t4) {
224 
225         for (int i = 0; i < 4; i++) {
226             setDouble(p, t0, t1, t2, t3, t4);
227         }
228     }
229 
230     /**
231      * Multiply an affine point by a scalar and return the result as a mutable
232      * point.
233      *
234      * @param affineP the point
235      * @param s the scalar as a little-endian array
236      * @return the product
237      */
multiply(AffinePoint affineP, byte[] s)238     public MutablePoint multiply(AffinePoint affineP, byte[] s) {
239 
240         // 4-bit windowed multiply with branchless lookup.
241         // The mixed addition is faster, so it is used to construct the array
242         // at the beginning of the operation.
243 
244         IntegerFieldModuloP field = affineP.getX().getField();
245         ImmutableIntegerModuloP zero = field.get0();
246         // temporaries
247         MutableIntegerModuloP t0 = zero.mutable();
248         MutableIntegerModuloP t1 = zero.mutable();
249         MutableIntegerModuloP t2 = zero.mutable();
250         MutableIntegerModuloP t3 = zero.mutable();
251         MutableIntegerModuloP t4 = zero.mutable();
252 
253         ProjectivePoint.Mutable result = new ProjectivePoint.Mutable(field);
254         result.getY().setValue(field.get1().mutable());
255 
256         ProjectivePoint.Immutable[] pointMultiples =
257             new ProjectivePoint.Immutable[16];
258         // 0P is neutral---same as initial result value
259         pointMultiples[0] = result.fixed();
260 
261         ProjectivePoint.Mutable ps = new ProjectivePoint.Mutable(field);
262         ps.setValue(affineP);
263         // 1P = P
264         pointMultiples[1] = ps.fixed();
265 
266         // the rest are calculated using mixed point addition
267         for (int i = 2; i < 16; i++) {
268             setSum(ps, affineP, t0, t1, t2, t3, t4);
269             pointMultiples[i] = ps.fixed();
270         }
271 
272         ProjectivePoint.Mutable lookupResult = ps.mutable();
273 
274         for (int i = s.length - 1; i >= 0; i--) {
275 
276             double4(result, t0, t1, t2, t3, t4);
277 
278             int high = (0xFF & s[i]) >>> 4;
279             lookup4(pointMultiples, high, lookupResult, zero);
280             setSum(result, lookupResult, t0, t1, t2, t3, t4);
281 
282             double4(result, t0, t1, t2, t3, t4);
283 
284             int low = 0xF & s[i];
285             lookup4(pointMultiples, low, lookupResult, zero);
286             setSum(result, lookupResult, t0, t1, t2, t3, t4);
287         }
288 
289         return result;
290 
291     }
292 
293     /*
294      * Point double
295      */
setDouble(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)296     private void setDouble(ProjectivePoint.Mutable p, MutableIntegerModuloP t0,
297         MutableIntegerModuloP t1, MutableIntegerModuloP t2,
298         MutableIntegerModuloP t3, MutableIntegerModuloP t4) {
299 
300         t0.setValue(p.getX()).setSquare();
301         t1.setValue(p.getY()).setSquare();
302         t2.setValue(p.getZ()).setSquare();
303         t3.setValue(p.getX()).setProduct(p.getY());
304         t4.setValue(p.getY()).setProduct(p.getZ());
305 
306         t3.setSum(t3);
307         p.getZ().setProduct(p.getX());
308 
309         p.getZ().setProduct(two);
310 
311         p.getY().setValue(t2).setProduct(b);
312         p.getY().setDifference(p.getZ());
313 
314         p.getX().setValue(p.getY()).setProduct(two);
315         p.getY().setSum(p.getX());
316         p.getY().setReduced();
317         p.getX().setValue(t1).setDifference(p.getY());
318 
319         p.getY().setSum(t1);
320         p.getY().setProduct(p.getX());
321         p.getX().setProduct(t3);
322 
323         t3.setValue(t2).setProduct(two);
324         t2.setSum(t3);
325         p.getZ().setProduct(b);
326 
327         t2.setReduced();
328         p.getZ().setDifference(t2);
329         p.getZ().setDifference(t0);
330         t3.setValue(p.getZ()).setProduct(two);
331         p.getZ().setReduced();
332         p.getZ().setSum(t3);
333         t0.setProduct(three);
334 
335         t0.setDifference(t2);
336         t0.setProduct(p.getZ());
337         p.getY().setSum(t0);
338 
339         t4.setSum(t4);
340         p.getZ().setProduct(t4);
341 
342         p.getX().setDifference(p.getZ());
343         p.getZ().setValue(t4).setProduct(t1);
344 
345         p.getZ().setProduct(four);
346 
347     }
348 
349     /*
350      * Mixed point addition. This method constructs new temporaries each time
351      * it is called. For better efficiency, the method that reuses temporaries
352      * should be used if more than one sum will be computed.
353      */
setSum(MutablePoint p, AffinePoint p2)354     public void setSum(MutablePoint p, AffinePoint p2) {
355 
356         IntegerModuloP zero = p.getField().get0();
357         MutableIntegerModuloP t0 = zero.mutable();
358         MutableIntegerModuloP t1 = zero.mutable();
359         MutableIntegerModuloP t2 = zero.mutable();
360         MutableIntegerModuloP t3 = zero.mutable();
361         MutableIntegerModuloP t4 = zero.mutable();
362         setSum((ProjectivePoint.Mutable) p, p2, t0, t1, t2, t3, t4);
363 
364     }
365 
366     /*
367      * Mixed point addition
368      */
setSum(ProjectivePoint.Mutable p, AffinePoint p2, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)369     private void setSum(ProjectivePoint.Mutable p, AffinePoint p2,
370         MutableIntegerModuloP t0, MutableIntegerModuloP t1,
371         MutableIntegerModuloP t2, MutableIntegerModuloP t3,
372         MutableIntegerModuloP t4) {
373 
374         t0.setValue(p.getX()).setProduct(p2.getX());
375         t1.setValue(p.getY()).setProduct(p2.getY());
376         t3.setValue(p2.getX()).setSum(p2.getY());
377         t4.setValue(p.getX()).setSum(p.getY());
378         p.getX().setReduced();
379         t3.setProduct(t4);
380         t4.setValue(t0).setSum(t1);
381 
382         t3.setDifference(t4);
383         t4.setValue(p2.getY()).setProduct(p.getZ());
384         t4.setSum(p.getY());
385 
386         p.getY().setValue(p2.getX()).setProduct(p.getZ());
387         p.getY().setSum(p.getX());
388         t2.setValue(p.getZ());
389         p.getZ().setProduct(b);
390 
391         p.getX().setValue(p.getY()).setDifference(p.getZ());
392         p.getX().setReduced();
393         p.getZ().setValue(p.getX()).setProduct(two);
394         p.getX().setSum(p.getZ());
395 
396         p.getZ().setValue(t1).setDifference(p.getX());
397         p.getX().setSum(t1);
398         p.getY().setProduct(b);
399 
400         t1.setValue(t2).setProduct(two);
401         t2.setSum(t1);
402         t2.setReduced();
403         p.getY().setDifference(t2);
404 
405         p.getY().setDifference(t0);
406         p.getY().setReduced();
407         t1.setValue(p.getY()).setProduct(two);
408         p.getY().setSum(t1);
409 
410         t1.setValue(t0).setProduct(two);
411         t0.setSum(t1);
412         t0.setDifference(t2);
413 
414         t1.setValue(t4).setProduct(p.getY());
415         t2.setValue(t0).setProduct(p.getY());
416         p.getY().setValue(p.getX()).setProduct(p.getZ());
417 
418         p.getY().setSum(t2);
419         p.getX().setProduct(t3);
420         p.getX().setDifference(t1);
421 
422         p.getZ().setProduct(t4);
423         t1.setValue(t3).setProduct(t0);
424         p.getZ().setSum(t1);
425 
426     }
427 
428     /*
429      * Projective point addition
430      */
setSum(ProjectivePoint.Mutable p, ProjectivePoint.Mutable p2, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)431     private void setSum(ProjectivePoint.Mutable p, ProjectivePoint.Mutable p2,
432         MutableIntegerModuloP t0, MutableIntegerModuloP t1,
433         MutableIntegerModuloP t2, MutableIntegerModuloP t3,
434         MutableIntegerModuloP t4) {
435 
436         t0.setValue(p.getX()).setProduct(p2.getX());
437         t1.setValue(p.getY()).setProduct(p2.getY());
438         t2.setValue(p.getZ()).setProduct(p2.getZ());
439 
440         t3.setValue(p.getX()).setSum(p.getY());
441         t4.setValue(p2.getX()).setSum(p2.getY());
442         t3.setProduct(t4);
443 
444         t4.setValue(t0).setSum(t1);
445         t3.setDifference(t4);
446         t4.setValue(p.getY()).setSum(p.getZ());
447 
448         p.getY().setValue(p2.getY()).setSum(p2.getZ());
449         t4.setProduct(p.getY());
450         p.getY().setValue(t1).setSum(t2);
451 
452         t4.setDifference(p.getY());
453         p.getX().setSum(p.getZ());
454         p.getY().setValue(p2.getX()).setSum(p2.getZ());
455 
456         p.getX().setProduct(p.getY());
457         p.getY().setValue(t0).setSum(t2);
458         p.getY().setAdditiveInverse().setSum(p.getX());
459         p.getY().setReduced();
460 
461         p.getZ().setValue(t2).setProduct(b);
462         p.getX().setValue(p.getY()).setDifference(p.getZ());
463         p.getZ().setValue(p.getX()).setProduct(two);
464 
465         p.getX().setSum(p.getZ());
466         p.getX().setReduced();
467         p.getZ().setValue(t1).setDifference(p.getX());
468         p.getX().setSum(t1);
469 
470         p.getY().setProduct(b);
471         t1.setValue(t2).setSum(t2);
472         t2.setSum(t1);
473         t2.setReduced();
474 
475         p.getY().setDifference(t2);
476         p.getY().setDifference(t0);
477         p.getY().setReduced();
478         t1.setValue(p.getY()).setSum(p.getY());
479 
480         p.getY().setSum(t1);
481         t1.setValue(t0).setProduct(two);
482         t0.setSum(t1);
483 
484         t0.setDifference(t2);
485         t1.setValue(t4).setProduct(p.getY());
486         t2.setValue(t0).setProduct(p.getY());
487 
488         p.getY().setValue(p.getX()).setProduct(p.getZ());
489         p.getY().setSum(t2);
490         p.getX().setProduct(t3);
491 
492         p.getX().setDifference(t1);
493         p.getZ().setProduct(t4);
494         t1.setValue(t3).setProduct(t0);
495 
496         p.getZ().setSum(t1);
497 
498     }
499 }
500