1 /*
2  * Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 8181594 8208648
27  * @summary Test proper operation of integer field arithmetic
28  * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly
29  * @build BigIntegerModuloP
30  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0
31  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1
32  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2
33  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP256 32 5
34  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP384 48 6
35  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP521 66 7
36  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P256OrderField 32 8
37  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P384OrderField 48 9
38  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P521OrderField 66 10
39  * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve25519OrderField 32 11
40  * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve448OrderField 56 12
41  */
42 
43 import sun.security.util.math.*;
44 import sun.security.util.math.intpoly.*;
45 import java.util.function.*;
46 
47 import java.util.*;
48 import java.math.*;
49 import java.nio.*;
50 
51 public class TestIntegerModuloP {
52 
53     static BigInteger TWO = BigInteger.valueOf(2);
54 
55     // The test has a list of functions, and it selects randomly from that list
56 
57     // The function types
58     interface ElemFunction extends BiFunction
59         <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }
60     interface ElemArrayFunction extends BiFunction
61         <MutableIntegerModuloP, IntegerModuloP, byte[]> { }
62     interface TriConsumer <T, U, V> {
accept(T t, U u, V v)63         void accept(T t, U u, V v);
64     }
65     interface ElemSetFunction extends TriConsumer
66         <MutableIntegerModuloP, IntegerModuloP, byte[]> { }
67 
68     // The lists of functions. Multiple lists are needed because the test
69     // respects the limitations of the arithmetic implementations.
70     static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();
71     static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();
72     static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();
73     static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();
74 
setUpFunctions(IntegerFieldModuloP field, int length)75     static void setUpFunctions(IntegerFieldModuloP field, int length) {
76 
77         ADD_FUNCTIONS.clear();
78         MULT_FUNCTIONS.clear();
79         SET_FUNCTIONS.clear();
80         ARRAY_FUNCTIONS.clear();
81 
82         byte highByte = (byte)
83             (field.getSize().bitLength() > length * 8 ? 1 : 0);
84 
85         // add functions are (im)mutable add/subtract
86         ADD_FUNCTIONS.add(IntegerModuloP::add);
87         ADD_FUNCTIONS.add(IntegerModuloP::subtract);
88         ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);
89         ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);
90         // also include functions that return the first/second argument
91         ADD_FUNCTIONS.add((a, b) -> a);
92         ADD_FUNCTIONS.add((a, b) -> b);
93 
94         // mult functions are (im)mutable multiply and square
95         MULT_FUNCTIONS.add(IntegerModuloP::multiply);
96         MULT_FUNCTIONS.add((a, b) -> a.square());
97         MULT_FUNCTIONS.add((a, b) -> b.square());
98         MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);
99         MULT_FUNCTIONS.add((a, b) -> a.setSquare());
100         // also test multiplication by a small value
101         MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(
102             b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));
103 
104         // set functions are setValue with various argument types
105         SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));
106         SET_FUNCTIONS.add((a, b, c) ->
107             a.setValue(c, 0, c.length, (byte) 0));
108         SET_FUNCTIONS.add((a, b, c) ->
109             a.setValue(c, 0, c.length / 2, (byte) 0));
110         SET_FUNCTIONS.add((a, b, c) ->
111             a.setValue(ByteBuffer.wrap(c, 0, c.length / 2).order(ByteOrder.LITTLE_ENDIAN),
112             c.length / 2, highByte));
113 
114         // array functions return the (possibly modified) value as byte array
115         ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length));
116         ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));
117     }
118 
main(String[] args)119     public static void main(String[] args) {
120 
121         String className = args[0];
122         final int length = Integer.parseInt(args[1]);
123         int seed = Integer.parseInt(args[2]);
124 
125         Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class;
126         try {
127             Class<? extends IntegerFieldModuloP> clazz =
128                 Class.forName(className).asSubclass(fieldBaseClass);
129             IntegerFieldModuloP field =
130                 clazz.getDeclaredConstructor().newInstance();
131 
132             setUpFunctions(field, length);
133 
134             runFieldTest(field, length, seed);
135         } catch (Exception ex) {
136             throw new RuntimeException(ex);
137         }
138         System.out.println("All tests passed");
139     }
140 
assertEqual(IntegerModuloP e1, IntegerModuloP e2)141     static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {
142 
143         if (!e1.asBigInteger().equals(e2.asBigInteger())) {
144             throw new RuntimeException("values not equal: "
145                 + e1.asBigInteger() + " != " + e2.asBigInteger());
146         }
147     }
148 
149     // A class that holds pairs of actual/expected values, and allows
150     // computation on these pairs.
151     static class TestPair<T extends IntegerModuloP> {
152         private final T test;
153         private final T baseline;
154 
TestPair(T test, T baseline)155         public TestPair(T test, T baseline) {
156             this.test = test;
157             this.baseline = baseline;
158         }
159 
getTest()160         public T getTest() {
161             return test;
162         }
getBaseline()163         public T getBaseline() {
164             return baseline;
165         }
166 
assertEqual()167         private void assertEqual() {
168             TestIntegerModuloP.assertEqual(test, baseline);
169         }
170 
mutable()171         public TestPair<MutableIntegerModuloP> mutable() {
172             return new TestPair<>(test.mutable(), baseline.mutable());
173         }
174 
175         public
176         <R extends IntegerModuloP, X extends IntegerModuloP>
apply(BiFunction<T, R, X> func, TestPair<R> right)177         TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {
178             X testResult = func.apply(test, right.test);
179             X baselineResult = func.apply(baseline, right.baseline);
180             return new TestPair(testResult, baselineResult);
181         }
182 
183         public
184         <U extends IntegerModuloP, V>
apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV)185         void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {
186             func.accept(test, right.test, argV);
187             func.accept(baseline, right.baseline, argV);
188         }
189 
190         public
191         <R extends IntegerModuloP>
applyAndCheckArray(BiFunction<T, R, byte[]> func, TestPair<R> right)192         void applyAndCheckArray(BiFunction<T, R, byte[]> func,
193                                 TestPair<R> right) {
194             byte[] testResult = func.apply(test, right.test);
195             byte[] baselineResult = func.apply(baseline, right.baseline);
196             if (!Arrays.equals(testResult, baselineResult)) {
197                 throw new RuntimeException("Array values do not match: "
198                     + HexFormat.of().withUpperCase().formatHex(testResult) + " != "
199                     + HexFormat.of().withUpperCase().formatHex(baselineResult));
200             }
201         }
202 
203     }
204 
205     static TestPair<IntegerModuloP>
applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)206     applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,
207                   TestPair<IntegerModuloP> right) {
208 
209         TestPair<IntegerModuloP> result = left.apply(func, right);
210         result.assertEqual();
211         left.assertEqual();
212         right.assertEqual();
213 
214         return result;
215     }
216 
217     static void
setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right, byte[] argV)218     setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,
219                 TestPair<IntegerModuloP> right, byte[] argV) {
220 
221         left.apply(func, right, argV);
222         left.assertEqual();
223         right.assertEqual();
224     }
225 
226     static TestPair<MutableIntegerModuloP>
applyAndCheckMutable(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)227     applyAndCheckMutable(ElemFunction func,
228                          TestPair<MutableIntegerModuloP> left,
229                          TestPair<IntegerModuloP> right) {
230 
231         TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);
232 
233         TestPair<MutableIntegerModuloP> mutableResult = result.mutable();
234         mutableResult.assertEqual();
235         result.assertEqual();
236         left.assertEqual();
237         right.assertEqual();
238 
239         return mutableResult;
240     }
241 
242     static void
cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left, TestPair<MutableIntegerModuloP> right)243     cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,
244                   TestPair<MutableIntegerModuloP> right) {
245 
246         left.getTest().conditionalSwapWith(right.getTest(), swap);
247         left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);
248 
249         left.assertEqual();
250         right.assertEqual();
251 
252     }
253 
254     // Request arithmetic that should overflow, and ensure that overflow is
255     // detected.
runOverflowTest(TestPair<IntegerModuloP> elem)256     static void runOverflowTest(TestPair<IntegerModuloP> elem) {
257 
258         TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();
259 
260         try {
261             for (int i = 0; i < 1000; i++) {
262                 applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);
263             }
264             applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);
265         } catch (ArithmeticException ex) {
266             // this is expected
267         }
268 
269         mutableElem = elem.mutable();
270         try {
271             for (int i = 0; i < 1000; i++) {
272                 elem = applyAndCheck(IntegerModuloP::add,
273                     mutableElem, elem);
274             }
275             applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);
276         } catch (ArithmeticException ex) {
277             // this is expected
278         }
279     }
280 
281     // Run a large number of random operations and ensure that
282     // results are correct
runOperationsTest(Random random, int length, TestPair<IntegerModuloP> elem, TestPair<IntegerModuloP> right)283     static void runOperationsTest(Random random, int length,
284                                   TestPair<IntegerModuloP> elem,
285                                   TestPair<IntegerModuloP> right) {
286 
287         TestPair<MutableIntegerModuloP> left = elem.mutable();
288 
289         for (int i = 0; i < 10000; i++) {
290 
291             ElemFunction addFunc1 =
292                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
293             TestPair<MutableIntegerModuloP> result1 =
294                 applyAndCheckMutable(addFunc1, left, right);
295 
296             // left could have been modified, so turn it back into a summand
297             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
298 
299             ElemFunction addFunc2 =
300                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
301             TestPair<IntegerModuloP> result2 =
302                 applyAndCheck(addFunc2, left, right);
303 
304             if (elem.test.getField() instanceof IntegerPolynomial) {
305                 IntegerPolynomial field =
306                     (IntegerPolynomial) elem.test.getField();
307                 int numAdds = field.getMaxAdds();
308                 for (int j = 1; j < numAdds; j++) {
309                     ElemFunction addFunc3 = ADD_FUNCTIONS.
310                         get(random.nextInt(ADD_FUNCTIONS.size()));
311                     result2 = applyAndCheck(addFunc3, left, right);
312                 }
313             }
314 
315             ElemFunction multFunc2 =
316                 MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));
317             TestPair<MutableIntegerModuloP> multResult =
318                 applyAndCheckMutable(multFunc2, result1, result2);
319 
320             int swap = random.nextInt(2);
321             cswapAndCheck(swap, left, multResult);
322 
323             ElemSetFunction setFunc =
324                 SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));
325             byte[] valueArr = new byte[2 * length];
326             random.nextBytes(valueArr);
327             setAndCheck(setFunc, result1, result2, valueArr);
328 
329             // left could have been modified, so to turn it back into a summand
330             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
331 
332             ElemArrayFunction arrayFunc =
333                 ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));
334             left.applyAndCheckArray(arrayFunc, right);
335         }
336     }
337 
338     // Run all the tests for a given field
runFieldTest(IntegerFieldModuloP testField, int length, int seed)339     static void runFieldTest(IntegerFieldModuloP testField,
340                              int length, int seed) {
341         System.out.println("Testing: " + testField.getClass().getSimpleName());
342 
343         Random random = new Random(seed);
344 
345         IntegerFieldModuloP baselineField =
346             new BigIntegerModuloP(testField.getSize());
347 
348         int numBits = testField.getSize().bitLength();
349         BigInteger r =
350             new BigInteger(numBits, random).mod(testField.getSize());
351         TestPair<IntegerModuloP> rand =
352             new TestPair(testField.getElement(r), baselineField.getElement(r));
353 
354         runOverflowTest(rand);
355 
356         // check combinations of operations for different kinds of elements
357         List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();
358         testElements.add(rand);
359         testElements.add(new TestPair(testField.get0(), baselineField.get0()));
360         testElements.add(new TestPair(testField.get1(), baselineField.get1()));
361         byte[] testArr = {121, 37, -100, -5, 76, 33};
362         testElements.add(new TestPair(testField.getElement(testArr),
363             baselineField.getElement(testArr)));
364 
365         testArr = new byte[length];
366         random.nextBytes(testArr);
367         testElements.add(new TestPair(testField.getElement(testArr),
368             baselineField.getElement(testArr)));
369 
370         random.nextBytes(testArr);
371         byte highByte = (byte) (numBits > length * 8 ? 1 : 0);
372         testElements.add(
373             new TestPair(
374                 testField.getElement(testArr, 0, testArr.length, highByte),
375                 baselineField.getElement(testArr, 0, testArr.length, highByte)
376             )
377         );
378 
379         for (int i = 0; i < testElements.size(); i++) {
380             for (int j = 0; j < testElements.size(); j++) {
381                 runOperationsTest(random, length, testElements.get(i),
382                     testElements.get(j));
383             }
384         }
385     }
386 }
387 
388