1 /*
2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 package sun.security.util.math.intpoly;
27 
28 import java.math.BigInteger;
29 import java.nio.ByteBuffer;
30 
31 /**
32  * The field of integers modulo a binomial prime. This is a general-purpose
33  * field implementation, that is much slower than more specialized classes
34  * like IntegerPolynomial25519. It is suitable when only a small number of
35  * arithmetic operations are required in some field. For example, this class
36  * can be used for operations on scalars/exponents in signature operations.
37  *
38  * This class may only be used for primes of the form 2^a + b.
39  */
40 
41 public class IntegerPolynomialModBinP extends IntegerPolynomial {
42 
43     private final long[] reduceLimbs;
44     private final int bitOffset;
45     private final int limbMask;
46     private final int rightBitOffset;
47     private final int power;
48 
IntegerPolynomialModBinP(int bitsPerLimb, int numLimbs, int power, BigInteger subtrahend)49     public IntegerPolynomialModBinP(int bitsPerLimb,
50                                     int numLimbs,
51                                     int power,
52                                     BigInteger subtrahend) {
53         super(bitsPerLimb, numLimbs, 1,
54             BigInteger.valueOf(2).pow(power).subtract(subtrahend));
55 
56         boolean negate = false;
57         if (subtrahend.compareTo(BigInteger.ZERO) < 0) {
58             negate = true;
59             subtrahend = subtrahend.negate();
60         }
61         int reduceLimbsLength = subtrahend.bitLength() / bitsPerLimb + 1;
62         reduceLimbs = new long[reduceLimbsLength];
63         ImmutableElement reduceElem = getElement(subtrahend);
64         if (negate) {
65             reduceElem = reduceElem.additiveInverse();
66         }
67         System.arraycopy(reduceElem.limbs, 0, reduceLimbs, 0,
68             reduceLimbs.length);
69 
70         // begin test code
71         System.out.println("reduce limbs:");
72         for (int i = 0; i < reduceLimbs.length; i++) {
73             System.out.println(i + ":" + reduceLimbs[i]);
74         }
75         // end test code
76 
77         this.power = power;
78         this.bitOffset = numLimbs * bitsPerLimb - power;
79         this.limbMask = -1 >>> (64 - bitsPerLimb);
80         this.rightBitOffset = bitsPerLimb - bitOffset;
81     }
82 
83     @Override
finalCarryReduceLast(long[] limbs)84     protected void finalCarryReduceLast(long[] limbs) {
85 
86         int extraBits = bitsPerLimb * numLimbs - power;
87         int highBits = bitsPerLimb - extraBits;
88         long c = limbs[numLimbs - 1] >> highBits;
89         limbs[numLimbs - 1] -= c << highBits;
90         for (int j = 0; j < reduceLimbs.length; j++) {
91             int reduceBits = power + extraBits - j * bitsPerLimb;
92             modReduceInBits(limbs, numLimbs, reduceBits, c * reduceLimbs[j]);
93         }
94     }
95 
96 
97     /**
98      * Allow more general (and slower) input conversion that takes a large
99      * value and reduces it.
100      */
101     @Override
getElement(byte[] v, int offset, int length, byte highByte)102     public ImmutableElement getElement(byte[] v, int offset, int length,
103                                        byte highByte) {
104 
105         long[] result = new long[numLimbs];
106         int numHighBits = 32 - Integer.numberOfLeadingZeros(highByte);
107         int numBits = 8 * length + numHighBits;
108         int requiredLimbs = (numBits + bitsPerLimb - 1) / bitsPerLimb;
109         if (requiredLimbs > numLimbs) {
110             long[] temp = new long[requiredLimbs];
111             encode(v, offset, length, highByte, temp);
112             // encode does a full carry/reduce
113             System.arraycopy(temp, 0, result, 0, result.length);
114         } else {
115             encode(v, offset, length, highByte, result);
116         }
117 
118         return new ImmutableElement(result, 0);
119     }
120 
121     /**
122      * Multiply a and b, and store the result in c. Requires that
123      * a.length == b.length == numLimbs and c.length >= 2 * numLimbs - 1.
124      * It is allowed for a and b to be the same array.
125      */
multOnly(long[] a, long[] b, long[] c)126     private void multOnly(long[] a, long[] b, long[] c) {
127         for (int i = 0; i < numLimbs; i++) {
128             for (int j = 0; j < numLimbs; j++) {
129                 c[i + j] += a[i] * b[j];
130             }
131         }
132     }
133 
134     @Override
mult(long[] a, long[] b, long[] r)135     protected void mult(long[] a, long[] b, long[] r) {
136 
137         long[] c = new long[2 * numLimbs];
138         multOnly(a, b, c);
139         carryReduce(c, r);
140     }
141 
modReduceInBits(long[] limbs, int index, int bits, long x)142     private void modReduceInBits(long[] limbs, int index, int bits, long x) {
143 
144         if (bits % bitsPerLimb == 0) {
145             int pos = bits / bitsPerLimb;
146             limbs[index - pos] += x;
147         }
148         else {
149             int secondPos = bits / (bitsPerLimb);
150             int bitOffset = (secondPos + 1) * bitsPerLimb - bits;
151             int rightBitOffset = bitsPerLimb - bitOffset;
152             limbs[index - (secondPos + 1)] += (x << bitOffset) & limbMask;
153             limbs[index - secondPos] += x >> rightBitOffset;
154         }
155     }
156 
reduceIn(long[] c, long v, int i)157     protected void reduceIn(long[] c, long v, int i) {
158 
159         for (int j = 0; j < reduceLimbs.length; j++) {
160             modReduceInBits(c, i, power - bitsPerLimb * j, reduceLimbs[j] * v);
161         }
162     }
163 
carryReduce(long[] c, long[] r)164     private void carryReduce(long[] c, long[] r) {
165 
166         // full carry to prevent overflow during reduce
167         carry(c);
168         // Reduce in from all high positions
169         for (int i = c.length - 1; i >= numLimbs; i--) {
170             reduceIn(c, c[i], i);
171             c[i] = 0;
172         }
173         // carry on lower positions that possibly carries out one position
174         carry(c, 0, numLimbs);
175         // reduce in a single position
176         reduceIn(c, c[numLimbs], numLimbs);
177         c[numLimbs] = 0;
178         // final carry
179         carry(c, 0, numLimbs - 1);
180         System.arraycopy(c, 0, r, 0, r.length);
181     }
182 
183     @Override
reduce(long[] a)184     protected void reduce(long[] a) {
185         // TODO: optimize this
186         long[] c = new long[a.length + 2];
187         System.arraycopy(a, 0, c, 0, a.length);
188         carryReduce(c, a);
189     }
190 
191     @Override
square(long[] a, long[] r)192     protected void square(long[] a, long[] r) {
193 
194         long[] c = new long[2 * numLimbs];
195         for (int i = 0; i < numLimbs; i++) {
196             c[2 * i] += a[i] * a[i];
197             for (int j = i + 1; j < numLimbs; j++) {
198                 c[i + j] += 2 * a[i] * a[j];
199             }
200         }
201 
202         carryReduce(c, r);
203 
204     }
205 
206     /**
207      * The field of integers modulo the order of the Curve25519 subgroup
208      */
209     public static class Curve25519OrderField extends IntegerPolynomialModBinP {
210 
Curve25519OrderField()211         public Curve25519OrderField() {
212             super(26, 10, 252,
213                 new BigInteger("-27742317777372353535851937790883648493"));
214         }
215     }
216 
217     /**
218      * The field of integers modulo the order of the Curve448 subgroup
219      */
220     public static class Curve448OrderField extends IntegerPolynomialModBinP {
221 
Curve448OrderField()222         public Curve448OrderField() {
223             super(28, 16, 446,
224                 new BigInteger("138180668098951153520073867485154268803366" +
225                     "92474882178609894547503885"));
226         }
227     }
228 }
229