1 /* 2 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @bug 8181594 8208648 27 * @summary Test proper operation of integer field arithmetic 28 * @modules java.base/sun.security.util java.base/sun.security.util.math java.base/sun.security.util.math.intpoly 29 * @build BigIntegerModuloP 30 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial25519 32 0 31 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial448 56 1 32 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomial1305 16 2 33 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP256 32 5 34 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP384 48 6 35 * @run main TestIntegerModuloP sun.security.util.math.intpoly.IntegerPolynomialP521 66 7 36 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P256OrderField 32 8 37 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P384OrderField 48 9 38 * @run main TestIntegerModuloP sun.security.util.math.intpoly.P521OrderField 66 10 39 */ 40 41 import sun.security.util.math.*; 42 import sun.security.util.math.intpoly.*; 43 import java.util.function.*; 44 45 import java.util.*; 46 import java.math.*; 47 import java.nio.*; 48 49 public class TestIntegerModuloP { 50 51 static BigInteger TWO = BigInteger.valueOf(2); 52 53 // The test has a list of functions, and it selects randomly from that list 54 55 // The function types 56 interface ElemFunction extends BiFunction 57 <MutableIntegerModuloP, IntegerModuloP, IntegerModuloP> { } 58 interface ElemArrayFunction extends BiFunction 59 <MutableIntegerModuloP, IntegerModuloP, byte[]> { } 60 interface TriConsumer <T, U, V> { accept(T t, U u, V v)61 void accept(T t, U u, V v); 62 } 63 interface ElemSetFunction extends TriConsumer 64 <MutableIntegerModuloP, IntegerModuloP, byte[]> { } 65 66 // The lists of functions. Multiple lists are needed because the test 67 // respects the limitations of the arithmetic implementations. 68 static final List<ElemFunction> ADD_FUNCTIONS = new ArrayList<>(); 69 static final List<ElemFunction> MULT_FUNCTIONS = new ArrayList<>(); 70 static final List<ElemArrayFunction> ARRAY_FUNCTIONS = new ArrayList<>(); 71 static final List<ElemSetFunction> SET_FUNCTIONS = new ArrayList<>(); 72 setUpFunctions(IntegerFieldModuloP field, int length)73 static void setUpFunctions(IntegerFieldModuloP field, int length) { 74 75 ADD_FUNCTIONS.clear(); 76 MULT_FUNCTIONS.clear(); 77 SET_FUNCTIONS.clear(); 78 ARRAY_FUNCTIONS.clear(); 79 80 byte highByte = (byte) 81 (field.getSize().bitLength() > length * 8 ? 1 : 0); 82 83 // add functions are (im)mutable add/subtract 84 ADD_FUNCTIONS.add(IntegerModuloP::add); 85 ADD_FUNCTIONS.add(IntegerModuloP::subtract); 86 ADD_FUNCTIONS.add(MutableIntegerModuloP::setSum); 87 ADD_FUNCTIONS.add(MutableIntegerModuloP::setDifference); 88 // also include functions that return the first/second argument 89 ADD_FUNCTIONS.add((a, b) -> a); 90 ADD_FUNCTIONS.add((a, b) -> b); 91 92 // mult functions are (im)mutable multiply and square 93 MULT_FUNCTIONS.add(IntegerModuloP::multiply); 94 MULT_FUNCTIONS.add((a, b) -> a.square()); 95 MULT_FUNCTIONS.add((a, b) -> b.square()); 96 MULT_FUNCTIONS.add(MutableIntegerModuloP::setProduct); 97 MULT_FUNCTIONS.add((a, b) -> a.setSquare()); 98 // also test multiplication by a small value 99 MULT_FUNCTIONS.add((a, b) -> a.setProduct(b.getField().getSmallValue( 100 b.asBigInteger().mod(BigInteger.valueOf(262144)).intValue()))); 101 102 // set functions are setValue with various argument types 103 SET_FUNCTIONS.add((a, b, c) -> a.setValue(b)); 104 SET_FUNCTIONS.add((a, b, c) -> 105 a.setValue(c, 0, c.length, (byte) 0)); 106 SET_FUNCTIONS.add((a, b, c) -> 107 a.setValue(ByteBuffer.wrap(c, 0, c.length).order(ByteOrder.LITTLE_ENDIAN), 108 c.length, highByte)); 109 110 // array functions return the (possibly modified) value as byte array 111 ARRAY_FUNCTIONS.add((a, b ) -> a.asByteArray(length)); 112 ARRAY_FUNCTIONS.add((a, b) -> a.addModPowerTwo(b, length)); 113 } 114 main(String[] args)115 public static void main(String[] args) { 116 117 String className = args[0]; 118 final int length = Integer.parseInt(args[1]); 119 int seed = Integer.parseInt(args[2]); 120 121 Class<IntegerFieldModuloP> fieldBaseClass = IntegerFieldModuloP.class; 122 try { 123 Class<? extends IntegerFieldModuloP> clazz = 124 Class.forName(className).asSubclass(fieldBaseClass); 125 IntegerFieldModuloP field = 126 clazz.getDeclaredConstructor().newInstance(); 127 128 setUpFunctions(field, length); 129 130 runFieldTest(field, length, seed); 131 } catch (Exception ex) { 132 throw new RuntimeException(ex); 133 } 134 System.out.println("All tests passed"); 135 } 136 assertEqual(IntegerModuloP e1, IntegerModuloP e2)137 static void assertEqual(IntegerModuloP e1, IntegerModuloP e2) { 138 139 if (!e1.asBigInteger().equals(e2.asBigInteger())) { 140 throw new RuntimeException("values not equal: " 141 + e1.asBigInteger() + " != " + e2.asBigInteger()); 142 } 143 } 144 145 // A class that holds pairs of actual/expected values, and allows 146 // computation on these pairs. 147 static class TestPair<T extends IntegerModuloP> { 148 private final T test; 149 private final T baseline; 150 TestPair(T test, T baseline)151 public TestPair(T test, T baseline) { 152 this.test = test; 153 this.baseline = baseline; 154 } 155 getTest()156 public T getTest() { 157 return test; 158 } getBaseline()159 public T getBaseline() { 160 return baseline; 161 } 162 assertEqual()163 private void assertEqual() { 164 TestIntegerModuloP.assertEqual(test, baseline); 165 } 166 mutable()167 public TestPair<MutableIntegerModuloP> mutable() { 168 return new TestPair<>(test.mutable(), baseline.mutable()); 169 } 170 171 public 172 <R extends IntegerModuloP, X extends IntegerModuloP> apply(BiFunction<T, R, X> func, TestPair<R> right)173 TestPair<X> apply(BiFunction<T, R, X> func, TestPair<R> right) { 174 X testResult = func.apply(test, right.test); 175 X baselineResult = func.apply(baseline, right.baseline); 176 return new TestPair(testResult, baselineResult); 177 } 178 179 public 180 <U extends IntegerModuloP, V> apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV)181 void apply(TriConsumer<T, U, V> func, TestPair<U> right, V argV) { 182 func.accept(test, right.test, argV); 183 func.accept(baseline, right.baseline, argV); 184 } 185 186 public 187 <R extends IntegerModuloP> applyAndCheckArray(BiFunction<T, R, byte[]> func, TestPair<R> right)188 void applyAndCheckArray(BiFunction<T, R, byte[]> func, 189 TestPair<R> right) { 190 byte[] testResult = func.apply(test, right.test); 191 byte[] baselineResult = func.apply(baseline, right.baseline); 192 if (!Arrays.equals(testResult, baselineResult)) { 193 throw new RuntimeException("Array values do not match: " 194 + byteArrayToHexString(testResult) + " != " 195 + byteArrayToHexString(baselineResult)); 196 } 197 } 198 199 } 200 byteArrayToHexString(byte[] arr)201 static String byteArrayToHexString(byte[] arr) { 202 StringBuilder result = new StringBuilder(); 203 for (int i = 0; i < arr.length; ++i) { 204 byte curVal = arr[i]; 205 result.append(Character.forDigit(curVal >> 4 & 0xF, 16)); 206 result.append(Character.forDigit(curVal & 0xF, 16)); 207 } 208 return result.toString(); 209 } 210 211 static TestPair<IntegerModuloP> applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)212 applyAndCheck(ElemFunction func, TestPair<MutableIntegerModuloP> left, 213 TestPair<IntegerModuloP> right) { 214 215 TestPair<IntegerModuloP> result = left.apply(func, right); 216 result.assertEqual(); 217 left.assertEqual(); 218 right.assertEqual(); 219 220 return result; 221 } 222 223 static void setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right, byte[] argV)224 setAndCheck(ElemSetFunction func, TestPair<MutableIntegerModuloP> left, 225 TestPair<IntegerModuloP> right, byte[] argV) { 226 227 left.apply(func, right, argV); 228 left.assertEqual(); 229 right.assertEqual(); 230 } 231 232 static TestPair<MutableIntegerModuloP> applyAndCheckMutable(ElemFunction func, TestPair<MutableIntegerModuloP> left, TestPair<IntegerModuloP> right)233 applyAndCheckMutable(ElemFunction func, 234 TestPair<MutableIntegerModuloP> left, 235 TestPair<IntegerModuloP> right) { 236 237 TestPair<IntegerModuloP> result = applyAndCheck(func, left, right); 238 239 TestPair<MutableIntegerModuloP> mutableResult = result.mutable(); 240 mutableResult.assertEqual(); 241 result.assertEqual(); 242 left.assertEqual(); 243 right.assertEqual(); 244 245 return mutableResult; 246 } 247 248 static void cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left, TestPair<MutableIntegerModuloP> right)249 cswapAndCheck(int swap, TestPair<MutableIntegerModuloP> left, 250 TestPair<MutableIntegerModuloP> right) { 251 252 left.getTest().conditionalSwapWith(right.getTest(), swap); 253 left.getBaseline().conditionalSwapWith(right.getBaseline(), swap); 254 255 left.assertEqual(); 256 right.assertEqual(); 257 258 } 259 260 // Request arithmetic that should overflow, and ensure that overflow is 261 // detected. runOverflowTest(TestPair<IntegerModuloP> elem)262 static void runOverflowTest(TestPair<IntegerModuloP> elem) { 263 264 TestPair<MutableIntegerModuloP> mutableElem = elem.mutable(); 265 266 try { 267 for (int i = 0; i < 1000; i++) { 268 applyAndCheck(MutableIntegerModuloP::setSum, mutableElem, elem); 269 } 270 applyAndCheck(MutableIntegerModuloP::setProduct, mutableElem, elem); 271 } catch (ArithmeticException ex) { 272 // this is expected 273 } 274 275 mutableElem = elem.mutable(); 276 try { 277 for (int i = 0; i < 1000; i++) { 278 elem = applyAndCheck(IntegerModuloP::add, 279 mutableElem, elem); 280 } 281 applyAndCheck(IntegerModuloP::multiply, mutableElem, elem); 282 } catch (ArithmeticException ex) { 283 // this is expected 284 } 285 } 286 287 // Run a large number of random operations and ensure that 288 // results are correct runOperationsTest(Random random, int length, TestPair<IntegerModuloP> elem, TestPair<IntegerModuloP> right)289 static void runOperationsTest(Random random, int length, 290 TestPair<IntegerModuloP> elem, 291 TestPair<IntegerModuloP> right) { 292 293 TestPair<MutableIntegerModuloP> left = elem.mutable(); 294 295 for (int i = 0; i < 10000; i++) { 296 297 ElemFunction addFunc1 = 298 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size())); 299 TestPair<MutableIntegerModuloP> result1 = 300 applyAndCheckMutable(addFunc1, left, right); 301 302 // left could have been modified, so turn it back into a summand 303 applyAndCheckMutable((a, b) -> a.setSquare(), left, right); 304 305 ElemFunction addFunc2 = 306 ADD_FUNCTIONS.get(random.nextInt(ADD_FUNCTIONS.size())); 307 TestPair<IntegerModuloP> result2 = 308 applyAndCheck(addFunc2, left, right); 309 310 if (elem.test.getField() instanceof IntegerPolynomial) { 311 IntegerPolynomial field = 312 (IntegerPolynomial) elem.test.getField(); 313 int numAdds = field.getMaxAdds(); 314 for (int j = 1; j < numAdds; j++) { 315 ElemFunction addFunc3 = ADD_FUNCTIONS. 316 get(random.nextInt(ADD_FUNCTIONS.size())); 317 result2 = applyAndCheck(addFunc3, left, right); 318 } 319 } 320 321 ElemFunction multFunc2 = 322 MULT_FUNCTIONS.get(random.nextInt(MULT_FUNCTIONS.size())); 323 TestPair<MutableIntegerModuloP> multResult = 324 applyAndCheckMutable(multFunc2, result1, result2); 325 326 int swap = random.nextInt(2); 327 cswapAndCheck(swap, left, multResult); 328 329 ElemSetFunction setFunc = 330 SET_FUNCTIONS.get(random.nextInt(SET_FUNCTIONS.size())); 331 byte[] valueArr = new byte[length]; 332 random.nextBytes(valueArr); 333 setAndCheck(setFunc, result1, result2, valueArr); 334 335 // left could have been modified, so to turn it back into a summand 336 applyAndCheckMutable((a, b) -> a.setSquare(), left, right); 337 338 ElemArrayFunction arrayFunc = 339 ARRAY_FUNCTIONS.get(random.nextInt(ARRAY_FUNCTIONS.size())); 340 left.applyAndCheckArray(arrayFunc, right); 341 } 342 } 343 344 // Run all the tests for a given field runFieldTest(IntegerFieldModuloP testField, int length, int seed)345 static void runFieldTest(IntegerFieldModuloP testField, 346 int length, int seed) { 347 System.out.println("Testing: " + testField.getClass().getSimpleName()); 348 349 Random random = new Random(seed); 350 351 IntegerFieldModuloP baselineField = 352 new BigIntegerModuloP(testField.getSize()); 353 354 int numBits = testField.getSize().bitLength(); 355 BigInteger r = 356 new BigInteger(numBits, random).mod(testField.getSize()); 357 TestPair<IntegerModuloP> rand = 358 new TestPair(testField.getElement(r), baselineField.getElement(r)); 359 360 runOverflowTest(rand); 361 362 // check combinations of operations for different kinds of elements 363 List<TestPair<IntegerModuloP>> testElements = new ArrayList<>(); 364 testElements.add(rand); 365 testElements.add(new TestPair(testField.get0(), baselineField.get0())); 366 testElements.add(new TestPair(testField.get1(), baselineField.get1())); 367 byte[] testArr = {121, 37, -100, -5, 76, 33}; 368 testElements.add(new TestPair(testField.getElement(testArr), 369 baselineField.getElement(testArr))); 370 371 testArr = new byte[length]; 372 random.nextBytes(testArr); 373 testElements.add(new TestPair(testField.getElement(testArr), 374 baselineField.getElement(testArr))); 375 376 random.nextBytes(testArr); 377 byte highByte = (byte) (numBits > length * 8 ? 1 : 0); 378 testElements.add( 379 new TestPair( 380 testField.getElement(testArr, 0, testArr.length, highByte), 381 baselineField.getElement(testArr, 0, testArr.length, highByte) 382 ) 383 ); 384 385 for (int i = 0; i < testElements.size(); i++) { 386 for (int j = 0; j < testElements.size(); j++) { 387 runOperationsTest(random, length, testElements.get(i), 388 testElements.get(j)); 389 } 390 } 391 } 392 } 393 394