1 /*
2  * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @bug 8181594 8208648
27  * @summary Test proper operation of integer field arithmetic
28  * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly
29  * @build BigIntegerModuloP
30  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0
31  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1
32  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2
33  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP256 32 5
34  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP384 48 6
35  * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP521 66 7
36  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P256OrderField 32 8
37  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P384OrderField 48 9
38  * @run main TestIntegerModuloP sun.security.util.math.intpoly.P521OrderField 66 10
39  */
40 
41 import sun.security.util.math.*;
42 import sun.security.util.math.intpoly.*;
43 import java.util.function.*;
44 
45 import java.util.*;
46 import java.math.*;
47 import java.nio.*;
48 
49 public class TestIntegerModuloP {
50 
51     static BigInteger TWO = BigInteger.valueOf(2);
52 
53     // The test has a list of functions, and it selects randomly from that list
54 
55     // The function types
56     interface ElemFunction extends BiFunction
57         <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { }
58     interface ElemArrayFunction extends BiFunction
59         <MutableIntegerModuloP, IntegerModuloP, byte[]> { }
60     interface TriConsumer <T, U, V> {
accept(T t, U u, V v)61         void accept(T t, U u, V v);
62     }
63     interface ElemSetFunction extends TriConsumer
64         <MutableIntegerModuloP, IntegerModuloP, byte[]> { }
65 
66     // The lists of functions. Multiple lists are needed because the test
67     // respects the limitations of the arithmetic implementations.
68     static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>();
69     static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>();
70     static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>();
71     static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>();
72 
setUpFunctions(IntegerFieldModuloP field, int length)73     static void setUpFunctions(IntegerFieldModuloP field, int length) {
74 
75         ADD_FUNCTIONS.clear();
76         MULT_FUNCTIONS.clear();
77         SET_FUNCTIONS.clear();
78         ARRAY_FUNCTIONS.clear();
79 
80         byte highByte = (byte)
81             (field.getSize().bitLength() > length * 8 ? 1 : 0);
82 
83         // add functions are (im)mutable add/subtract
84         ADD_FUNCTIONS.add(IntegerModuloP::add);
85         ADD_FUNCTIONS.add(IntegerModuloP::subtract);
86         ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum);
87         ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference);
88         // also include functions that return the first/second argument
89         ADD_FUNCTIONS.add((a, b) -> a);
90         ADD_FUNCTIONS.add((a, b) -> b);
91 
92         // mult functions are (im)mutable multiply and square
93         MULT_FUNCTIONS.add(IntegerModuloP::multiply);
94         MULT_FUNCTIONS.add((a, b) -> a.square());
95         MULT_FUNCTIONS.add((a, b) -> b.square());
96         MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct);
97         MULT_FUNCTIONS.add((a, b) -> a.setSquare());
98         // also test multiplication by a small value
99         MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue(
100             b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue())));
101 
102         // set functions are setValue with various argument types
103         SET_FUNCTIONS.add((a, b, c) -> a.setValue(b));
104         SET_FUNCTIONS.add((a, b, c) ->
105             a.setValue(c, 0, c.length, (byte) 0));
106         SET_FUNCTIONS.add((a, b, c) ->
107             a.setValue(ByteBuffer.wrap(c, 0, c.length).order(ByteOrder.LITTLE_ENDIAN),
108             c.length, highByte));
109 
110         // array functions return the (possibly modified) value as byte array
111         ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length));
112         ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length));
113     }
114 
main(String[] args)115     public static void main(String[] args) {
116 
117         String className = args[0];
118         final int length = Integer.parseInt(args[1]);
119         int seed = Integer.parseInt(args[2]);
120 
121         Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class;
122         try {
123             Class<? extends IntegerFieldModuloP> clazz =
124                 Class.forName(className).asSubclass(fieldBaseClass);
125             IntegerFieldModuloP field =
126                 clazz.getDeclaredConstructor().newInstance();
127 
128             setUpFunctions(field, length);
129 
130             runFieldTest(field, length, seed);
131         } catch (Exception ex) {
132             throw new RuntimeException(ex);
133         }
134         System.out.println("All tests passed");
135     }
136 
assertEqual(IntegerModuloP e1, IntegerModuloP e2)137     static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) {
138 
139         if (!e1.asBigInteger().equals(e2.asBigInteger())) {
140             throw new RuntimeException("values not equal: "
141                 + e1.asBigInteger() + " != " + e2.asBigInteger());
142         }
143     }
144 
145     // A class that holds pairs of actual/expected values, and allows
146     // computation on these pairs.
147     static class TestPair<T extends IntegerModuloP> {
148         private final T test;
149         private final T baseline;
150 
TestPair(T test, T baseline)151         public TestPair(T test, T baseline) {
152             this.test = test;
153             this.baseline = baseline;
154         }
155 
getTest()156         public T getTest() {
157             return test;
158         }
getBaseline()159         public T getBaseline() {
160             return baseline;
161         }
162 
assertEqual()163         private void assertEqual() {
164             TestIntegerModuloP.assertEqual(test, baseline);
165         }
166 
mutable()167         public TestPair<MutableIntegerModuloP> mutable() {
168             return new TestPair<>(test.mutable(), baseline.mutable());
169         }
170 
171         public
172         <R extends IntegerModuloP, X extends IntegerModuloP>
apply(BiFunction<T, R, X> func, TestPair<R> right)173         TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) {
174             X testResult = func.apply(test, right.test);
175             X baselineResult = func.apply(baseline, right.baseline);
176             return new TestPair(testResult, baselineResult);
177         }
178 
179         public
180         <U extends IntegerModuloP, V>
apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV)181         void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) {
182             func.accept(test, right.test, argV);
183             func.accept(baseline, right.baseline, argV);
184         }
185 
186         public
187         <R extends IntegerModuloP>
applyAndCheckArray(BiFunction<T, R, byte[]> func, TestPair<R> right)188         void applyAndCheckArray(BiFunction<T, R, byte[]> func,
189                                 TestPair<R> right) {
190             byte[] testResult = func.apply(test, right.test);
191             byte[] baselineResult = func.apply(baseline, right.baseline);
192             if (!Arrays.equals(testResult, baselineResult)) {
193                 throw new RuntimeException("Array values do not match: "
194                     + byteArrayToHexString(testResult) + " != "
195                     + byteArrayToHexString(baselineResult));
196             }
197         }
198 
199     }
200 
byteArrayToHexString(byte[] arr)201     static String byteArrayToHexString(byte[] arr) {
202         StringBuilder result = new StringBuilder();
203         for (int i = 0; i < arr.length; ++i) {
204             byte curVal = arr[i];
205             result.append(Character.forDigit(curVal >> 4 & 0xF, 16));
206             result.append(Character.forDigit(curVal & 0xF, 16));
207         }
208         return result.toString();
209     }
210 
211     static TestPair<IntegerModuloP>
applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)212     applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left,
213                   TestPair<IntegerModuloP> right) {
214 
215         TestPair<IntegerModuloP> result = left.apply(func, right);
216         result.assertEqual();
217         left.assertEqual();
218         right.assertEqual();
219 
220         return result;
221     }
222 
223     static void
setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right, byte[] argV)224     setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left,
225                 TestPair<IntegerModuloP> right, byte[] argV) {
226 
227         left.apply(func, right, argV);
228         left.assertEqual();
229         right.assertEqual();
230     }
231 
232     static TestPair<MutableIntegerModuloP>
applyAndCheckMutable(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)233     applyAndCheckMutable(ElemFunction func,
234                          TestPair<MutableIntegerModuloP> left,
235                          TestPair<IntegerModuloP> right) {
236 
237         TestPair<IntegerModuloP> result = applyAndCheck(func, left, right);
238 
239         TestPair<MutableIntegerModuloP> mutableResult = result.mutable();
240         mutableResult.assertEqual();
241         result.assertEqual();
242         left.assertEqual();
243         right.assertEqual();
244 
245         return mutableResult;
246     }
247 
248     static void
cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left, TestPair<MutableIntegerModuloP> right)249     cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left,
250                   TestPair<MutableIntegerModuloP> right) {
251 
252         left.getTest().conditionalSwapWith(right.getTest(), swap);
253         left.getBaseline().conditionalSwapWith(right.getBaseline(), swap);
254 
255         left.assertEqual();
256         right.assertEqual();
257 
258     }
259 
260     // Request arithmetic that should overflow, and ensure that overflow is
261     // detected.
runOverflowTest(TestPair<IntegerModuloP> elem)262     static void runOverflowTest(TestPair<IntegerModuloP> elem) {
263 
264         TestPair<MutableIntegerModuloP> mutableElem = elem.mutable();
265 
266         try {
267             for (int i = 0; i < 1000; i++) {
268                 applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem);
269             }
270             applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem);
271         } catch (ArithmeticException ex) {
272             // this is expected
273         }
274 
275         mutableElem = elem.mutable();
276         try {
277             for (int i = 0; i < 1000; i++) {
278                 elem = applyAndCheck(IntegerModuloP::add,
279                     mutableElem, elem);
280             }
281             applyAndCheck(IntegerModuloP::multiply, mutableElem, elem);
282         } catch (ArithmeticException ex) {
283             // this is expected
284         }
285     }
286 
287     // Run a large number of random operations and ensure that
288     // results are correct
runOperationsTest(Random random, int length, TestPair<IntegerModuloP> elem, TestPair<IntegerModuloP> right)289     static void runOperationsTest(Random random, int length,
290                                   TestPair<IntegerModuloP> elem,
291                                   TestPair<IntegerModuloP> right) {
292 
293         TestPair<MutableIntegerModuloP> left = elem.mutable();
294 
295         for (int i = 0; i < 10000; i++) {
296 
297             ElemFunction addFunc1 =
298                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
299             TestPair<MutableIntegerModuloP> result1 =
300                 applyAndCheckMutable(addFunc1, left, right);
301 
302             // left could have been modified, so turn it back into a summand
303             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
304 
305             ElemFunction addFunc2 =
306                 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size()));
307             TestPair<IntegerModuloP> result2 =
308                 applyAndCheck(addFunc2, left, right);
309 
310             if (elem.test.getField() instanceof IntegerPolynomial) {
311                 IntegerPolynomial field =
312                     (IntegerPolynomial) elem.test.getField();
313                 int numAdds = field.getMaxAdds();
314                 for (int j = 1; j < numAdds; j++) {
315                     ElemFunction addFunc3 = ADD_FUNCTIONS.
316                         get(random.nextInt(ADD_FUNCTIONS.size()));
317                     result2 = applyAndCheck(addFunc3, left, right);
318                 }
319             }
320 
321             ElemFunction multFunc2 =
322                 MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size()));
323             TestPair<MutableIntegerModuloP> multResult =
324                 applyAndCheckMutable(multFunc2, result1, result2);
325 
326             int swap = random.nextInt(2);
327             cswapAndCheck(swap, left, multResult);
328 
329             ElemSetFunction setFunc =
330                 SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size()));
331             byte[] valueArr = new byte[length];
332             random.nextBytes(valueArr);
333             setAndCheck(setFunc, result1, result2, valueArr);
334 
335             // left could have been modified, so to turn it back into a summand
336             applyAndCheckMutable((a, b) -> a.setSquare(), left, right);
337 
338             ElemArrayFunction arrayFunc =
339                 ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size()));
340             left.applyAndCheckArray(arrayFunc, right);
341         }
342     }
343 
344     // Run all the tests for a given field
runFieldTest(IntegerFieldModuloP testField, int length, int seed)345     static void runFieldTest(IntegerFieldModuloP testField,
346                              int length, int seed) {
347         System.out.println("Testing: " + testField.getClass().getSimpleName());
348 
349         Random random = new Random(seed);
350 
351         IntegerFieldModuloP baselineField =
352             new BigIntegerModuloP(testField.getSize());
353 
354         int numBits = testField.getSize().bitLength();
355         BigInteger r =
356             new BigInteger(numBits, random).mod(testField.getSize());
357         TestPair<IntegerModuloP> rand =
358             new TestPair(testField.getElement(r), baselineField.getElement(r));
359 
360         runOverflowTest(rand);
361 
362         // check combinations of operations for different kinds of elements
363         List<TestPair<IntegerModuloP>> testElements = new ArrayList<>();
364         testElements.add(rand);
365         testElements.add(new TestPair(testField.get0(), baselineField.get0()));
366         testElements.add(new TestPair(testField.get1(), baselineField.get1()));
367         byte[] testArr = {121, 37, -100, -5, 76, 33};
368         testElements.add(new TestPair(testField.getElement(testArr),
369             baselineField.getElement(testArr)));
370 
371         testArr = new byte[length];
372         random.nextBytes(testArr);
373         testElements.add(new TestPair(testField.getElement(testArr),
374             baselineField.getElement(testArr)));
375 
376         random.nextBytes(testArr);
377         byte highByte = (byte) (numBits > length * 8 ? 1 : 0);
378         testElements.add(
379             new TestPair(
380                 testField.getElement(testArr, 0, testArr.length, highByte),
381                 baselineField.getElement(testArr, 0, testArr.length, highByte)
382             )
383         );
384 
385         for (int i = 0; i < testElements.size(); i++) {
386             for (int j = 0; j < testElements.size(); j++) {
387                 runOperationsTest(random, length, testElements.get(i),
388                     testElements.get(j));
389             }
390         }
391     }
392 }
393 
394