1 /*
2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 package sun.security.ec.ed;
26 
27 import sun.security.ec.point.AffinePoint;
28 import sun.security.ec.point.Point;
29 import sun.security.util.ArrayUtil;
30 import sun.security.util.math.IntegerFieldModuloP;
31 import sun.security.util.math.IntegerModuloP;
32 import sun.security.util.math.MutableIntegerModuloP;
33 
34 import java.math.BigInteger;
35 import java.security.InvalidKeyException;
36 import java.security.NoSuchAlgorithmException;
37 import java.security.SecureRandom;
38 import java.security.SignatureException;
39 import java.security.spec.EdDSAParameterSpec;
40 import java.security.spec.EdECPoint;
41 import java.util.Arrays;
42 import java.util.function.Function;
43 
44 /*
45  * A class containing the operations of the EdDSA signature scheme. The
46  * parameters include an object that performs the elliptic curve point
47  * arithmetic, and EdDSAOperations uses this object to construct the signing
48  * and verification operations.
49  */
50 public class EdDSAOperations {
51 
52     private final EdDSAParameters params;
53 
EdDSAOperations(EdDSAParameters params)54     public EdDSAOperations(EdDSAParameters params)
55         throws NoSuchAlgorithmException {
56 
57         this.params = params;
58     }
59 
getParameters()60     public EdDSAParameters getParameters() {
61         return params;
62     }
63 
generatePrivate(SecureRandom random)64     public byte[] generatePrivate(SecureRandom random) {
65         byte[] result = new byte[params.getKeyLength()];
66         random.nextBytes(result);
67         return result;
68     }
69 
computePublic(byte[] privateKey)70     public EdECPoint computePublic(byte[] privateKey) {
71         byte[] privateKeyHash = params.digest(privateKey);
72         int byteLength = privateKeyHash.length / 2;
73         byte[] s = Arrays.copyOf(privateKeyHash, byteLength);
74         prune(s);
75         IntegerModuloP fieldS = params.getOrderField().getElement(s);
76         fieldS.asByteArray(s);
77         Point A = params.getEdOperations().basePointMultiply(s);
78         return asEdECPoint(A.asAffine());
79     }
80 
asEdECPoint(AffinePoint p)81     private static EdECPoint asEdECPoint(AffinePoint p) {
82         return new EdECPoint(p.getX().asBigInteger().testBit(0),
83             p.getY().asBigInteger());
84     }
85 
sign(EdDSAParameterSpec sigParams, byte[] privateKey, byte[] message)86     public byte[] sign(EdDSAParameterSpec sigParams, byte[] privateKey,
87                        byte[] message) {
88 
89         byte[] privateKeyHash = params.digest(privateKey);
90 
91         int byteLength = privateKeyHash.length / 2;
92         byte[] s = Arrays.copyOf(privateKeyHash, byteLength);
93         prune(s);
94         IntegerModuloP sElem = params.getOrderField().getElement(s);
95         sElem.asByteArray(s);
96         Point A = params.getEdOperations().basePointMultiply(s);
97         byte[] prefix = Arrays.copyOfRange(privateKeyHash,
98             privateKeyHash.length / 2, privateKeyHash.length);
99         byte[] dom = params.dom(sigParams);
100         byte[] r = params.digest(dom, prefix, message);
101 
102         // reduce r modulo the order
103         IntegerModuloP fieldR = params.getOrderField().getElement(r);
104         r = new byte[params.getKeyLength()];
105         fieldR.asByteArray(r);
106 
107         Point R = params.getEdOperations().basePointMultiply(r);
108 
109         byte[] encodedR = encode(byteLength, R);
110         byte[] encodedA = encode(byteLength, A);
111         byte[] k = params.digest(dom, encodedR, encodedA, message);
112 
113         // S computation is in group-order field
114         IntegerFieldModuloP subField = params.getOrderField();
115         IntegerModuloP kElem = subField.getElement(k);
116         IntegerModuloP rElem = subField.getElement(r);
117         MutableIntegerModuloP S = kElem.mutable().setProduct(sElem);
118         S.setSum(rElem);
119         // need to be reduced before output conversion
120         S.setReduced();
121         byte[] sArr = S.asByteArray(byteLength);
122         byte[] rArr = encode(byteLength, R);
123 
124         byte[] result = new byte[byteLength * 2];
125         System.arraycopy(rArr, 0, result, 0, byteLength);
126         System.arraycopy(sArr, 0, result, byteLength, byteLength);
127         return result;
128     }
129 
verify(EdDSAParameterSpec sigParams, AffinePoint affineA, byte[] publicKey, byte[] message, byte[] signature)130     public boolean verify(EdDSAParameterSpec sigParams, AffinePoint affineA,
131                           byte[] publicKey, byte[] message, byte[] signature)
132         throws SignatureException {
133 
134         if (signature == null) {
135             throw new SignatureException("signature was null");
136         }
137         byte[] encR = Arrays.copyOf(signature, signature.length / 2);
138         byte[] encS = Arrays.copyOfRange(signature, signature.length / 2,
139             signature.length);
140 
141         // reject s if it is too large
142         ArrayUtil.reverse(encS);
143         BigInteger bigS = new BigInteger(1, encS);
144         if (bigS.compareTo(params.getOrderField().getSize()) >= 0) {
145             throw new SignatureException("s is too large");
146         }
147         ArrayUtil.reverse(encS);
148 
149         byte[] dom = params.dom(sigParams);
150         AffinePoint affineR = decodeAffinePoint(SignatureException::new, encR);
151         byte[] k = params.digest(dom, encR, publicKey, message);
152         // reduce k to improve performance of multiply
153         IntegerFieldModuloP subField = params.getOrderField();
154         IntegerModuloP kElem = subField.getElement(k);
155         k = kElem.asByteArray(k.length / 2);
156 
157         Point pointR = params.getEdOperations().of(affineR);
158         Point pointA = params.getEdOperations().of(affineA);
159 
160         EdECOperations edOps = params.getEdOperations();
161         Point lhs = edOps.basePointMultiply(encS);
162         Point rhs = edOps.setSum(edOps.setProduct(pointA.mutable(), k),
163             pointR.mutable());
164 
165         return lhs.affineEquals(rhs);
166     }
167 
verify(EdDSAParameterSpec sigParams, byte[] publicKey, byte[] message, byte[] signature)168     public boolean verify(EdDSAParameterSpec sigParams, byte[] publicKey,
169                           byte[] message, byte[] signature)
170         throws InvalidKeyException, SignatureException {
171 
172         AffinePoint affineA = decodeAffinePoint(InvalidKeyException::new,
173             publicKey);
174         return verify(sigParams, affineA, publicKey, message, signature);
175     }
176 
177     public
178     <T extends Throwable>
decodeAffinePoint(Function<String, T> exception, byte[] arr)179     AffinePoint decodeAffinePoint(Function<String, T> exception, byte[] arr)
180     throws T {
181 
182         if (arr.length != params.getKeyLength()) {
183             throw exception.apply("incorrect length");
184         }
185 
186         arr = arr.clone();
187         int xLSB = (0xFF & arr[arr.length - 1]) >>> 7;
188         arr[arr.length - 1] &= 0x7F;
189         int yLength = (params.getBits() + 7) >> 3;
190         IntegerModuloP y =
191             params.getField().getElement(arr, 0, yLength, (byte) 0);
192         // reject non-canonical y values
193         ArrayUtil.reverse(arr);
194         BigInteger bigY = new BigInteger(1, arr);
195         if (bigY.compareTo(params.getField().getSize()) >= 0) {
196             throw exception.apply("y value is too large");
197         }
198         return params.getEdOperations().decodeAffinePoint(exception, xLSB, y);
199     }
200 
201     public
202     <T extends Throwable>
decodeAffinePoint(Function<String, T> exception, EdECPoint point)203     AffinePoint decodeAffinePoint(Function<String, T> exception,
204                                   EdECPoint point)
205         throws T {
206 
207         // reject non-canonical y values
208         if (point.getY().compareTo(params.getField().getSize()) >= 0) {
209             throw exception.apply("y value is too large");
210         }
211 
212         int xLSB = point.isXOdd() ? 1 : 0;
213         IntegerModuloP y = params.getField().getElement(point.getY());
214         return params.getEdOperations().decodeAffinePoint(exception, xLSB, y);
215     }
216 
217     /**
218      * Mask off the high order bits of an encoded integer in an array. The
219      * array is modified in place.
220      *
221      * @param arr an array containing an encoded integer
222      * @param bits the number of bits to keep
223      * @return the number, in range [0,8], of bits kept in the highest byte
224      */
maskHighOrder(byte[] arr, int bits)225     private static int maskHighOrder(byte[] arr, int bits) {
226 
227         int lastByteIndex = arr.length - 1;
228         int bitsDiff = arr.length * 8 - bits;
229         int highBits = 8 - bitsDiff;
230         byte msbMaskOff = (byte) ((1 << highBits) - 1);
231         arr[lastByteIndex] &= msbMaskOff;
232 
233         return highBits;
234     }
235 
236     /**
237      * Prune an encoded scalar value by modifying it in place. The extra
238      * high-order bits are masked off, the highest valid bit it set, and the
239      * number is rounded down to a multiple of the co-factor.
240      *
241      * @param k an encoded scalar value
242      * @param bits the number of bits in the scalar
243      * @param logCofactor the base-2 logarithm of the co-factor
244      */
prune(byte[] k, int bits, int logCofactor)245     private static void prune(byte[] k, int bits, int logCofactor) {
246 
247         int lastByteIndex = k.length - 1;
248 
249         // mask off unused high-order bits
250         int highBits = maskHighOrder(k, bits);
251 
252         // set the highest bit
253         if (highBits == 0) {
254             k[lastByteIndex - 1] |= 0x80;
255         } else {
256             byte msbMaskOn = (byte) (1 << (highBits - 1));
257             k[lastByteIndex] |= msbMaskOn;
258         }
259 
260         // round down to a multiple of the co-factor
261         byte lsbMaskOff = (byte) (0xFF << logCofactor);
262         k[0] &= lsbMaskOff;
263     }
264 
prune(byte[] arr)265     void prune(byte[] arr) {
266         prune(arr, params.getBits(), params.getLogCofactor());
267     }
268 
encode(int length, Point p)269     private static byte[] encode(int length, Point p) {
270         return encode(length, p.asAffine());
271     }
272 
encode(int length, AffinePoint p)273     private static byte[] encode(int length, AffinePoint p) {
274         byte[] result = p.getY().asByteArray(length);
275         int xLSB = p.getX().asByteArray(1)[0] & 0x01;
276         result[result.length - 1] |= (xLSB << 7);
277         return result;
278     }
279 }
280