1 /* 2 * Copyright (c) 2018, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. Oracle designates this 8 * particular file as subject to the "Classpath" exception as provided 9 * by Oracle in the LICENSE file that accompanied this code. 10 * 11 * This code is distributed in the hope that it will be useful, but WITHOUT 12 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 13 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 14 * version 2 for more details (a copy is included in the LICENSE file that 15 * accompanied this code). 16 * 17 * You should have received a copy of the GNU General Public License version 18 * 2 along with this work; if not, write to the Free Software Foundation, 19 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 20 * 21 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 22 * or visit www.oracle.com if you need additional information or have any 23 * questions. 24 */ 25 26 package sun.security.ec; 27 28 import sun.security.ec.point.*; 29 import sun.security.util.math.*; 30 import sun.security.util.math.intpoly.*; 31 32 import java.math.BigInteger; 33 import java.security.ProviderException; 34 import java.security.spec.ECFieldFp; 35 import java.security.spec.ECParameterSpec; 36 import java.security.spec.EllipticCurve; 37 import java.util.Map; 38 import java.util.Optional; 39 40 /* 41 * Elliptic curve point arithmetic for prime-order curves where a=-3. 42 * Formulas are derived from "Complete addition formulas for prime order 43 * elliptic curves" by Renes, Costello, and Batina. 44 */ 45 46 public class ECOperations { 47 48 /* 49 * An exception indicating a problem with an intermediate value produced 50 * by some part of the computation. For example, the signing operation 51 * will throw this exception to indicate that the r or s value is 0, and 52 * that the signing operation should be tried again with a different nonce. 53 */ 54 static class IntermediateValueException extends Exception { 55 private static final long serialVersionUID = 1; 56 } 57 58 static final Map<BigInteger, IntegerFieldModuloP> fields = Map.of( 59 IntegerPolynomialP256.MODULUS, new IntegerPolynomialP256(), 60 IntegerPolynomialP384.MODULUS, new IntegerPolynomialP384(), 61 IntegerPolynomialP521.MODULUS, new IntegerPolynomialP521() 62 ); 63 64 static final Map<BigInteger, IntegerFieldModuloP> orderFields = Map.of( 65 P256OrderField.MODULUS, new P256OrderField(), 66 P384OrderField.MODULUS, new P384OrderField(), 67 P521OrderField.MODULUS, new P521OrderField() 68 ); 69 forParameters(ECParameterSpec params)70 public static Optional<ECOperations> forParameters(ECParameterSpec params) { 71 72 EllipticCurve curve = params.getCurve(); 73 if (!(curve.getField() instanceof ECFieldFp)) { 74 return Optional.empty(); 75 } 76 ECFieldFp primeField = (ECFieldFp) curve.getField(); 77 78 BigInteger three = BigInteger.valueOf(3); 79 if (!primeField.getP().subtract(curve.getA()).equals(three)) { 80 return Optional.empty(); 81 } 82 IntegerFieldModuloP field = fields.get(primeField.getP()); 83 if (field == null) { 84 return Optional.empty(); 85 } 86 87 IntegerFieldModuloP orderField = orderFields.get(params.getOrder()); 88 if (orderField == null) { 89 return Optional.empty(); 90 } 91 92 ImmutableIntegerModuloP b = field.getElement(curve.getB()); 93 ECOperations ecOps = new ECOperations(b, orderField); 94 return Optional.of(ecOps); 95 } 96 97 final ImmutableIntegerModuloP b; 98 final SmallValue one; 99 final SmallValue two; 100 final SmallValue three; 101 final SmallValue four; 102 final ProjectivePoint.Immutable neutral; 103 private final IntegerFieldModuloP orderField; 104 ECOperations(IntegerModuloP b, IntegerFieldModuloP orderField)105 public ECOperations(IntegerModuloP b, IntegerFieldModuloP orderField) { 106 this.b = b.fixed(); 107 this.orderField = orderField; 108 109 this.one = b.getField().getSmallValue(1); 110 this.two = b.getField().getSmallValue(2); 111 this.three = b.getField().getSmallValue(3); 112 this.four = b.getField().getSmallValue(4); 113 114 IntegerFieldModuloP field = b.getField(); 115 this.neutral = new ProjectivePoint.Immutable(field.get0(), 116 field.get1(), field.get0()); 117 } 118 getField()119 public IntegerFieldModuloP getField() { 120 return b.getField(); 121 } getOrderField()122 public IntegerFieldModuloP getOrderField() { 123 return orderField; 124 } 125 getNeutral()126 protected ProjectivePoint.Immutable getNeutral() { 127 return neutral; 128 } 129 isNeutral(Point p)130 public boolean isNeutral(Point p) { 131 ProjectivePoint<?> pp = (ProjectivePoint<?>) p; 132 133 IntegerModuloP z = pp.getZ(); 134 135 IntegerFieldModuloP field = z.getField(); 136 int byteLength = (field.getSize().bitLength() + 7) / 8; 137 byte[] zBytes = z.asByteArray(byteLength); 138 return allZero(zBytes); 139 } 140 seedToScalar(byte[] seedBytes)141 byte[] seedToScalar(byte[] seedBytes) 142 throws IntermediateValueException { 143 144 // Produce a nonce from the seed using FIPS 186-4,section B.5.1: 145 // Per-Message Secret Number Generation Using Extra Random Bits 146 // or 147 // Produce a scalar from the seed using FIPS 186-4, section B.4.1: 148 // Key Pair Generation Using Extra Random Bits 149 150 // To keep the implementation simple, sample in the range [0,n) 151 // and throw IntermediateValueException in the (unlikely) event 152 // that the result is 0. 153 154 // Get 64 extra bits and reduce in to the nonce 155 int seedBits = orderField.getSize().bitLength() + 64; 156 if (seedBytes.length * 8 < seedBits) { 157 throw new ProviderException("Incorrect seed length: " + 158 seedBytes.length * 8 + " < " + seedBits); 159 } 160 161 // input conversion only works on byte boundaries 162 // clear high-order bits of last byte so they don't influence nonce 163 int lastByteBits = seedBits % 8; 164 if (lastByteBits != 0) { 165 int lastByteIndex = seedBits / 8; 166 byte mask = (byte) (0xFF >>> (8 - lastByteBits)); 167 seedBytes[lastByteIndex] &= mask; 168 } 169 170 int seedLength = (seedBits + 7) / 8; 171 IntegerModuloP scalarElem = 172 orderField.getElement(seedBytes, 0, seedLength, (byte) 0); 173 int scalarLength = (orderField.getSize().bitLength() + 7) / 8; 174 byte[] scalarArr = new byte[scalarLength]; 175 scalarElem.asByteArray(scalarArr); 176 if (ECOperations.allZero(scalarArr)) { 177 throw new IntermediateValueException(); 178 } 179 return scalarArr; 180 } 181 182 /* 183 * Compare all values in the array to 0 without branching on any value 184 * 185 */ allZero(byte[] arr)186 public static boolean allZero(byte[] arr) { 187 byte acc = 0; 188 for (int i = 0; i < arr.length; i++) { 189 acc |= arr[i]; 190 } 191 return acc == 0; 192 } 193 194 /* 195 * 4-bit branchless array lookup for projective points. 196 */ lookup4(ProjectivePoint.Immutable[] arr, int index, ProjectivePoint.Mutable result, IntegerModuloP zero)197 private void lookup4(ProjectivePoint.Immutable[] arr, int index, 198 ProjectivePoint.Mutable result, IntegerModuloP zero) { 199 200 for (int i = 0; i < 16; i++) { 201 int xor = index ^ i; 202 int bit3 = (xor & 0x8) >>> 3; 203 int bit2 = (xor & 0x4) >>> 2; 204 int bit1 = (xor & 0x2) >>> 1; 205 int bit0 = (xor & 0x1); 206 int inverse = bit0 | bit1 | bit2 | bit3; 207 int set = 1 - inverse; 208 209 ProjectivePoint.Immutable pi = arr[i]; 210 result.conditionalSet(pi, set); 211 } 212 } 213 double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)214 private void double4(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, 215 MutableIntegerModuloP t1, MutableIntegerModuloP t2, 216 MutableIntegerModuloP t3, MutableIntegerModuloP t4) { 217 218 for (int i = 0; i < 4; i++) { 219 setDouble(p, t0, t1, t2, t3, t4); 220 } 221 } 222 223 /** 224 * Multiply an affine point by a scalar and return the result as a mutable 225 * point. 226 * 227 * @param affineP the point 228 * @param s the scalar as a little-endian array 229 * @return the product 230 */ multiply(AffinePoint affineP, byte[] s)231 public MutablePoint multiply(AffinePoint affineP, byte[] s) { 232 233 // 4-bit windowed multiply with branchless lookup. 234 // The mixed addition is faster, so it is used to construct the array 235 // at the beginning of the operation. 236 237 IntegerFieldModuloP field = affineP.getX().getField(); 238 ImmutableIntegerModuloP zero = field.get0(); 239 // temporaries 240 MutableIntegerModuloP t0 = zero.mutable(); 241 MutableIntegerModuloP t1 = zero.mutable(); 242 MutableIntegerModuloP t2 = zero.mutable(); 243 MutableIntegerModuloP t3 = zero.mutable(); 244 MutableIntegerModuloP t4 = zero.mutable(); 245 246 ProjectivePoint.Mutable result = new ProjectivePoint.Mutable(field); 247 result.getY().setValue(field.get1().mutable()); 248 249 ProjectivePoint.Immutable[] pointMultiples = 250 new ProjectivePoint.Immutable[16]; 251 // 0P is neutral---same as initial result value 252 pointMultiples[0] = result.fixed(); 253 254 ProjectivePoint.Mutable ps = new ProjectivePoint.Mutable(field); 255 ps.setValue(affineP); 256 // 1P = P 257 pointMultiples[1] = ps.fixed(); 258 259 // the rest are calculated using mixed point addition 260 for (int i = 2; i < 16; i++) { 261 setSum(ps, affineP, t0, t1, t2, t3, t4); 262 pointMultiples[i] = ps.fixed(); 263 } 264 265 ProjectivePoint.Mutable lookupResult = ps.mutable(); 266 267 for (int i = s.length - 1; i >= 0; i--) { 268 269 double4(result, t0, t1, t2, t3, t4); 270 271 int high = (0xFF & s[i]) >>> 4; 272 lookup4(pointMultiples, high, lookupResult, zero); 273 setSum(result, lookupResult, t0, t1, t2, t3, t4); 274 275 double4(result, t0, t1, t2, t3, t4); 276 277 int low = 0xF & s[i]; 278 lookup4(pointMultiples, low, lookupResult, zero); 279 setSum(result, lookupResult, t0, t1, t2, t3, t4); 280 } 281 282 return result; 283 284 } 285 286 /* 287 * Point double 288 */ setDouble(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)289 private void setDouble(ProjectivePoint.Mutable p, MutableIntegerModuloP t0, 290 MutableIntegerModuloP t1, MutableIntegerModuloP t2, 291 MutableIntegerModuloP t3, MutableIntegerModuloP t4) { 292 293 t0.setValue(p.getX()).setSquare(); 294 t1.setValue(p.getY()).setSquare(); 295 t2.setValue(p.getZ()).setSquare(); 296 t3.setValue(p.getX()).setProduct(p.getY()); 297 t4.setValue(p.getY()).setProduct(p.getZ()); 298 299 t3.setSum(t3); 300 p.getZ().setProduct(p.getX()); 301 302 p.getZ().setProduct(two); 303 304 p.getY().setValue(t2).setProduct(b); 305 p.getY().setDifference(p.getZ()); 306 307 p.getX().setValue(p.getY()).setProduct(two); 308 p.getY().setSum(p.getX()); 309 p.getY().setReduced(); 310 p.getX().setValue(t1).setDifference(p.getY()); 311 312 p.getY().setSum(t1); 313 p.getY().setProduct(p.getX()); 314 p.getX().setProduct(t3); 315 316 t3.setValue(t2).setProduct(two); 317 t2.setSum(t3); 318 p.getZ().setProduct(b); 319 320 t2.setReduced(); 321 p.getZ().setDifference(t2); 322 p.getZ().setDifference(t0); 323 t3.setValue(p.getZ()).setProduct(two); 324 p.getZ().setReduced(); 325 p.getZ().setSum(t3); 326 t0.setProduct(three); 327 328 t0.setDifference(t2); 329 t0.setProduct(p.getZ()); 330 p.getY().setSum(t0); 331 332 t4.setSum(t4); 333 p.getZ().setProduct(t4); 334 335 p.getX().setDifference(p.getZ()); 336 p.getZ().setValue(t4).setProduct(t1); 337 338 p.getZ().setProduct(four); 339 340 } 341 342 /* 343 * Mixed point addition. This method constructs new temporaries each time 344 * it is called. For better efficiency, the method that reuses temporaries 345 * should be used if more than one sum will be computed. 346 */ setSum(MutablePoint p, AffinePoint p2)347 public void setSum(MutablePoint p, AffinePoint p2) { 348 349 IntegerModuloP zero = p.getField().get0(); 350 MutableIntegerModuloP t0 = zero.mutable(); 351 MutableIntegerModuloP t1 = zero.mutable(); 352 MutableIntegerModuloP t2 = zero.mutable(); 353 MutableIntegerModuloP t3 = zero.mutable(); 354 MutableIntegerModuloP t4 = zero.mutable(); 355 setSum((ProjectivePoint.Mutable) p, p2, t0, t1, t2, t3, t4); 356 357 } 358 359 /* 360 * Mixed point addition 361 */ setSum(ProjectivePoint.Mutable p, AffinePoint p2, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)362 private void setSum(ProjectivePoint.Mutable p, AffinePoint p2, 363 MutableIntegerModuloP t0, MutableIntegerModuloP t1, 364 MutableIntegerModuloP t2, MutableIntegerModuloP t3, 365 MutableIntegerModuloP t4) { 366 367 t0.setValue(p.getX()).setProduct(p2.getX()); 368 t1.setValue(p.getY()).setProduct(p2.getY()); 369 t3.setValue(p2.getX()).setSum(p2.getY()); 370 t4.setValue(p.getX()).setSum(p.getY()); 371 p.getX().setReduced(); 372 t3.setProduct(t4); 373 t4.setValue(t0).setSum(t1); 374 375 t3.setDifference(t4); 376 t4.setValue(p2.getY()).setProduct(p.getZ()); 377 t4.setSum(p.getY()); 378 379 p.getY().setValue(p2.getX()).setProduct(p.getZ()); 380 p.getY().setSum(p.getX()); 381 t2.setValue(p.getZ()); 382 p.getZ().setProduct(b); 383 384 p.getX().setValue(p.getY()).setDifference(p.getZ()); 385 p.getX().setReduced(); 386 p.getZ().setValue(p.getX()).setProduct(two); 387 p.getX().setSum(p.getZ()); 388 389 p.getZ().setValue(t1).setDifference(p.getX()); 390 p.getX().setSum(t1); 391 p.getY().setProduct(b); 392 393 t1.setValue(t2).setProduct(two); 394 t2.setSum(t1); 395 t2.setReduced(); 396 p.getY().setDifference(t2); 397 398 p.getY().setDifference(t0); 399 p.getY().setReduced(); 400 t1.setValue(p.getY()).setProduct(two); 401 p.getY().setSum(t1); 402 403 t1.setValue(t0).setProduct(two); 404 t0.setSum(t1); 405 t0.setDifference(t2); 406 407 t1.setValue(t4).setProduct(p.getY()); 408 t2.setValue(t0).setProduct(p.getY()); 409 p.getY().setValue(p.getX()).setProduct(p.getZ()); 410 411 p.getY().setSum(t2); 412 p.getX().setProduct(t3); 413 p.getX().setDifference(t1); 414 415 p.getZ().setProduct(t4); 416 t1.setValue(t3).setProduct(t0); 417 p.getZ().setSum(t1); 418 419 } 420 421 /* 422 * Projective point addition 423 */ setSum(ProjectivePoint.Mutable p, ProjectivePoint.Mutable p2, MutableIntegerModuloP t0, MutableIntegerModuloP t1, MutableIntegerModuloP t2, MutableIntegerModuloP t3, MutableIntegerModuloP t4)424 private void setSum(ProjectivePoint.Mutable p, ProjectivePoint.Mutable p2, 425 MutableIntegerModuloP t0, MutableIntegerModuloP t1, 426 MutableIntegerModuloP t2, MutableIntegerModuloP t3, 427 MutableIntegerModuloP t4) { 428 429 t0.setValue(p.getX()).setProduct(p2.getX()); 430 t1.setValue(p.getY()).setProduct(p2.getY()); 431 t2.setValue(p.getZ()).setProduct(p2.getZ()); 432 433 t3.setValue(p.getX()).setSum(p.getY()); 434 t4.setValue(p2.getX()).setSum(p2.getY()); 435 t3.setProduct(t4); 436 437 t4.setValue(t0).setSum(t1); 438 t3.setDifference(t4); 439 t4.setValue(p.getY()).setSum(p.getZ()); 440 441 p.getY().setValue(p2.getY()).setSum(p2.getZ()); 442 t4.setProduct(p.getY()); 443 p.getY().setValue(t1).setSum(t2); 444 445 t4.setDifference(p.getY()); 446 p.getX().setSum(p.getZ()); 447 p.getY().setValue(p2.getX()).setSum(p2.getZ()); 448 449 p.getX().setProduct(p.getY()); 450 p.getY().setValue(t0).setSum(t2); 451 p.getY().setAdditiveInverse().setSum(p.getX()); 452 p.getY().setReduced(); 453 454 p.getZ().setValue(t2).setProduct(b); 455 p.getX().setValue(p.getY()).setDifference(p.getZ()); 456 p.getZ().setValue(p.getX()).setProduct(two); 457 458 p.getX().setSum(p.getZ()); 459 p.getX().setReduced(); 460 p.getZ().setValue(t1).setDifference(p.getX()); 461 p.getX().setSum(t1); 462 463 p.getY().setProduct(b); 464 t1.setValue(t2).setSum(t2); 465 t2.setSum(t1); 466 t2.setReduced(); 467 468 p.getY().setDifference(t2); 469 p.getY().setDifference(t0); 470 p.getY().setReduced(); 471 t1.setValue(p.getY()).setSum(p.getY()); 472 473 p.getY().setSum(t1); 474 t1.setValue(t0).setProduct(two); 475 t0.setSum(t1); 476 477 t0.setDifference(t2); 478 t1.setValue(t4).setProduct(p.getY()); 479 t2.setValue(t0).setProduct(p.getY()); 480 481 p.getY().setValue(p.getX()).setProduct(p.getZ()); 482 p.getY().setSum(t2); 483 p.getX().setProduct(t3); 484 485 p.getX().setDifference(t1); 486 p.getZ().setProduct(t4); 487 t1.setValue(t3).setProduct(t0); 488 489 p.getZ().setSum(t1); 490 491 } 492 } 493 494