1 /* 2 * Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @bug 8181594 8208648 27 * @summary Test proper operation of integer field arithmetic 28 * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly 29 * @build BigIntegerModuloP 30 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0 31 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1 32 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2 33 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP256 32 5 34 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP384 48 6 35 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP521 66 7 36 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P256OrderField 32 8 37 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P384OrderField 48 9 38 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P521OrderField 66 10 39 * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve25519OrderField 32 11 40 * @run main TestIntegerModuloP sun.security.util.math.intpoly.Curve448OrderField 56 12 41 */ 42 43 import sun.security.util.math.*; 44 import sun.security.util.math.intpoly.*; 45 import java.util.function.*; 46 47 import java.util.*; 48 import java.math.*; 49 import java.nio.*; 50 51 public class TestIntegerModuloP { 52 53 static BigInteger TWO = BigInteger.valueOf(2); 54 55 // The test has a list of functions, and it selects randomly from that list 56 57 // The function types 58 interface ElemFunction extends BiFunction 59 <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { } 60 interface ElemArrayFunction extends BiFunction 61 <MutableIntegerModuloP, IntegerModuloP, byte[]> { } 62 interface TriConsumer <T, U, V> { accept(T t, U u, V v)63 void accept(T t, U u, V v); 64 } 65 interface ElemSetFunction extends TriConsumer 66 <MutableIntegerModuloP, IntegerModuloP, byte[]> { } 67 68 // The lists of functions. Multiple lists are needed because the test 69 // respects the limitations of the arithmetic implementations. 70 static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>(); 71 static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>(); 72 static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>(); 73 static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>(); 74 setUpFunctions(IntegerFieldModuloP field, int length)75 static void setUpFunctions(IntegerFieldModuloP field, int length) { 76 77 ADD_FUNCTIONS.clear(); 78 MULT_FUNCTIONS.clear(); 79 SET_FUNCTIONS.clear(); 80 ARRAY_FUNCTIONS.clear(); 81 82 byte highByte = (byte) 83 (field.getSize().bitLength() > length * 8 ? 1 : 0); 84 85 // add functions are (im)mutable add/subtract 86 ADD_FUNCTIONS.add(IntegerModuloP::add); 87 ADD_FUNCTIONS.add(IntegerModuloP::subtract); 88 ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum); 89 ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference); 90 // also include functions that return the first/second argument 91 ADD_FUNCTIONS.add((a, b) -> a); 92 ADD_FUNCTIONS.add((a, b) -> b); 93 94 // mult functions are (im)mutable multiply and square 95 MULT_FUNCTIONS.add(IntegerModuloP::multiply); 96 MULT_FUNCTIONS.add((a, b) -> a.square()); 97 MULT_FUNCTIONS.add((a, b) -> b.square()); 98 MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct); 99 MULT_FUNCTIONS.add((a, b) -> a.setSquare()); 100 // also test multiplication by a small value 101 MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue( 102 b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue()))); 103 104 // set functions are setValue with various argument types 105 SET_FUNCTIONS.add((a, b, c) -> a.setValue(b)); 106 SET_FUNCTIONS.add((a, b, c) -> 107 a.setValue(c, 0, c.length, (byte) 0)); 108 SET_FUNCTIONS.add((a, b, c) -> 109 a.setValue(c, 0, c.length / 2, (byte) 0)); 110 SET_FUNCTIONS.add((a, b, c) -> 111 a.setValue(ByteBuffer.wrap(c, 0, c.length / 2).order(ByteOrder.LITTLE_ENDIAN), 112 c.length / 2, highByte)); 113 114 // array functions return the (possibly modified) value as byte array 115 ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length)); 116 ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length)); 117 } 118 main(String[] args)119 public static void main(String[] args) { 120 121 String className = args[0]; 122 final int length = Integer.parseInt(args[1]); 123 int seed = Integer.parseInt(args[2]); 124 125 Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class; 126 try { 127 Class<? extends IntegerFieldModuloP> clazz = 128 Class.forName(className).asSubclass(fieldBaseClass); 129 IntegerFieldModuloP field = 130 clazz.getDeclaredConstructor().newInstance(); 131 132 setUpFunctions(field, length); 133 134 runFieldTest(field, length, seed); 135 } catch (Exception ex) { 136 throw new RuntimeException(ex); 137 } 138 System.out.println("All tests passed"); 139 } 140 assertEqual(IntegerModuloP e1, IntegerModuloP e2)141 static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) { 142 143 if (!e1.asBigInteger().equals(e2.asBigInteger())) { 144 throw new RuntimeException("values not equal: " 145 + e1.asBigInteger() + " != " + e2.asBigInteger()); 146 } 147 } 148 149 // A class that holds pairs of actual/expected values, and allows 150 // computation on these pairs. 151 static class TestPair<T extends IntegerModuloP> { 152 private final T test; 153 private final T baseline; 154 TestPair(T test, T baseline)155 public TestPair(T test, T baseline) { 156 this.test = test; 157 this.baseline = baseline; 158 } 159 getTest()160 public T getTest() { 161 return test; 162 } getBaseline()163 public T getBaseline() { 164 return baseline; 165 } 166 assertEqual()167 private void assertEqual() { 168 TestIntegerModuloP.assertEqual(test, baseline); 169 } 170 mutable()171 public TestPair<MutableIntegerModuloP> mutable() { 172 return new TestPair<>(test.mutable(), baseline.mutable()); 173 } 174 175 public 176 <R extends IntegerModuloP, X extends IntegerModuloP> apply(BiFunction<T, R, X> func, TestPair<R> right)177 TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) { 178 X testResult = func.apply(test, right.test); 179 X baselineResult = func.apply(baseline, right.baseline); 180 return new TestPair(testResult, baselineResult); 181 } 182 183 public 184 <U extends IntegerModuloP, V> apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV)185 void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) { 186 func.accept(test, right.test, argV); 187 func.accept(baseline, right.baseline, argV); 188 } 189 190 public 191 <R extends IntegerModuloP> applyAndCheckArray(BiFunction<T, R, byte[]> func, TestPair<R> right)192 void applyAndCheckArray(BiFunction<T, R, byte[]> func, 193 TestPair<R> right) { 194 byte[] testResult = func.apply(test, right.test); 195 byte[] baselineResult = func.apply(baseline, right.baseline); 196 if (!Arrays.equals(testResult, baselineResult)) { 197 throw new RuntimeException("Array values do not match: " 198 + HexFormat.of().withUpperCase().formatHex(testResult) + " != " 199 + HexFormat.of().withUpperCase().formatHex(baselineResult)); 200 } 201 } 202 203 } 204 205 static TestPair<IntegerModuloP> applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)206 applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left, 207 TestPair<IntegerModuloP> right) { 208 209 TestPair<IntegerModuloP> result = left.apply(func, right); 210 result.assertEqual(); 211 left.assertEqual(); 212 right.assertEqual(); 213 214 return result; 215 } 216 217 static void setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right, byte[] argV)218 setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left, 219 TestPair<IntegerModuloP> right, byte[] argV) { 220 221 left.apply(func, right, argV); 222 left.assertEqual(); 223 right.assertEqual(); 224 } 225 226 static TestPair<MutableIntegerModuloP> applyAndCheckMutable(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)227 applyAndCheckMutable(ElemFunction func, 228 TestPair<MutableIntegerModuloP> left, 229 TestPair<IntegerModuloP> right) { 230 231 TestPair<IntegerModuloP> result = applyAndCheck(func, left, right); 232 233 TestPair<MutableIntegerModuloP> mutableResult = result.mutable(); 234 mutableResult.assertEqual(); 235 result.assertEqual(); 236 left.assertEqual(); 237 right.assertEqual(); 238 239 return mutableResult; 240 } 241 242 static void cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left, TestPair<MutableIntegerModuloP> right)243 cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left, 244 TestPair<MutableIntegerModuloP> right) { 245 246 left.getTest().conditionalSwapWith(right.getTest(), swap); 247 left.getBaseline().conditionalSwapWith(right.getBaseline(), swap); 248 249 left.assertEqual(); 250 right.assertEqual(); 251 252 } 253 254 // Request arithmetic that should overflow, and ensure that overflow is 255 // detected. runOverflowTest(TestPair<IntegerModuloP> elem)256 static void runOverflowTest(TestPair<IntegerModuloP> elem) { 257 258 TestPair<MutableIntegerModuloP> mutableElem = elem.mutable(); 259 260 try { 261 for (int i = 0; i < 1000; i++) { 262 applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem); 263 } 264 applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem); 265 } catch (ArithmeticException ex) { 266 // this is expected 267 } 268 269 mutableElem = elem.mutable(); 270 try { 271 for (int i = 0; i < 1000; i++) { 272 elem = applyAndCheck(IntegerModuloP::add, 273 mutableElem, elem); 274 } 275 applyAndCheck(IntegerModuloP::multiply, mutableElem, elem); 276 } catch (ArithmeticException ex) { 277 // this is expected 278 } 279 } 280 281 // Run a large number of random operations and ensure that 282 // results are correct runOperationsTest(Random random, int length, TestPair<IntegerModuloP> elem, TestPair<IntegerModuloP> right)283 static void runOperationsTest(Random random, int length, 284 TestPair<IntegerModuloP> elem, 285 TestPair<IntegerModuloP> right) { 286 287 TestPair<MutableIntegerModuloP> left = elem.mutable(); 288 289 for (int i = 0; i < 10000; i++) { 290 291 ElemFunction addFunc1 = 292 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size())); 293 TestPair<MutableIntegerModuloP> result1 = 294 applyAndCheckMutable(addFunc1, left, right); 295 296 // left could have been modified, so turn it back into a summand 297 applyAndCheckMutable((a, b) -> a.setSquare(), left, right); 298 299 ElemFunction addFunc2 = 300 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size())); 301 TestPair<IntegerModuloP> result2 = 302 applyAndCheck(addFunc2, left, right); 303 304 if (elem.test.getField() instanceof IntegerPolynomial) { 305 IntegerPolynomial field = 306 (IntegerPolynomial) elem.test.getField(); 307 int numAdds = field.getMaxAdds(); 308 for (int j = 1; j < numAdds; j++) { 309 ElemFunction addFunc3 = ADD_FUNCTIONS. 310 get(random.nextInt(ADD_FUNCTIONS.size())); 311 result2 = applyAndCheck(addFunc3, left, right); 312 } 313 } 314 315 ElemFunction multFunc2 = 316 MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size())); 317 TestPair<MutableIntegerModuloP> multResult = 318 applyAndCheckMutable(multFunc2, result1, result2); 319 320 int swap = random.nextInt(2); 321 cswapAndCheck(swap, left, multResult); 322 323 ElemSetFunction setFunc = 324 SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size())); 325 byte[] valueArr = new byte[2 * length]; 326 random.nextBytes(valueArr); 327 setAndCheck(setFunc, result1, result2, valueArr); 328 329 // left could have been modified, so to turn it back into a summand 330 applyAndCheckMutable((a, b) -> a.setSquare(), left, right); 331 332 ElemArrayFunction arrayFunc = 333 ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size())); 334 left.applyAndCheckArray(arrayFunc, right); 335 } 336 } 337 338 // Run all the tests for a given field runFieldTest(IntegerFieldModuloP testField, int length, int seed)339 static void runFieldTest(IntegerFieldModuloP testField, 340 int length, int seed) { 341 System.out.println("Testing: " + testField.getClass().getSimpleName()); 342 343 Random random = new Random(seed); 344 345 IntegerFieldModuloP baselineField = 346 new BigIntegerModuloP(testField.getSize()); 347 348 int numBits = testField.getSize().bitLength(); 349 BigInteger r = 350 new BigInteger(numBits, random).mod(testField.getSize()); 351 TestPair<IntegerModuloP> rand = 352 new TestPair(testField.getElement(r), baselineField.getElement(r)); 353 354 runOverflowTest(rand); 355 356 // check combinations of operations for different kinds of elements 357 List<TestPair<IntegerModuloP>> testElements = new ArrayList<>(); 358 testElements.add(rand); 359 testElements.add(new TestPair(testField.get0(), baselineField.get0())); 360 testElements.add(new TestPair(testField.get1(), baselineField.get1())); 361 byte[] testArr = {121, 37, -100, -5, 76, 33}; 362 testElements.add(new TestPair(testField.getElement(testArr), 363 baselineField.getElement(testArr))); 364 365 testArr = new byte[length]; 366 random.nextBytes(testArr); 367 testElements.add(new TestPair(testField.getElement(testArr), 368 baselineField.getElement(testArr))); 369 370 random.nextBytes(testArr); 371 byte highByte = (byte) (numBits > length * 8 ? 1 : 0); 372 testElements.add( 373 new TestPair( 374 testField.getElement(testArr, 0, testArr.length, highByte), 375 baselineField.getElement(testArr, 0, testArr.length, highByte) 376 ) 377 ); 378 379 for (int i = 0; i < testElements.size(); i++) { 380 for (int j = 0; j < testElements.size(); j++) { 381 runOperationsTest(random, length, testElements.get(i), 382 testElements.get(j)); 383 } 384 } 385 } 386 } 387 388