1 /*
2  * Copyright (c) 2018, 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 package sun.security.util.math.intpoly;
27 
28 import java.lang.invoke.MethodHandles;
29 import java.lang.invoke.VarHandle;
30 import java.math.BigInteger;
31 import java.nio.*;
32 
33 /**
34  * An IntegerFieldModuloP designed for use with the Poly1305 authenticator.
35  * The representation uses 5 signed long values.
36  */
37 
38 public class IntegerPolynomial1305 extends IntegerPolynomial {
39 
40     protected static final int SUBTRAHEND = 5;
41     protected static final int NUM_LIMBS = 5;
42     private static final int POWER = 130;
43     private static final int BITS_PER_LIMB = 26;
44     private static final BigInteger MODULUS
45         = TWO.pow(POWER).subtract(BigInteger.valueOf(SUBTRAHEND));
46 
IntegerPolynomial1305()47     public IntegerPolynomial1305() {
48         super(BITS_PER_LIMB, NUM_LIMBS, 1, MODULUS);
49     }
50 
mult(long[] a, long[] b, long[] r)51     protected void mult(long[] a, long[] b, long[] r) {
52 
53         // Use grade-school multiplication into primitives to avoid the
54         // temporary array allocation. This is equivalent to the following
55         // code:
56         //  long[] c = new long[2 * NUM_LIMBS - 1];
57         //  for(int i = 0; i < NUM_LIMBS; i++) {
58         //      for(int j - 0; j < NUM_LIMBS; j++) {
59         //          c[i + j] += a[i] * b[j]
60         //      }
61         //  }
62 
63         long c0 = (a[0] * b[0]);
64         long c1 = (a[0] * b[1]) + (a[1] * b[0]);
65         long c2 = (a[0] * b[2]) + (a[1] * b[1]) + (a[2] * b[0]);
66         long c3 = (a[0] * b[3]) + (a[1] * b[2]) + (a[2] * b[1]) + (a[3] * b[0]);
67         long c4 = (a[0] * b[4]) + (a[1] * b[3]) + (a[2] * b[2]) + (a[3] * b[1]) + (a[4] * b[0]);
68         long c5 = (a[1] * b[4]) + (a[2] * b[3]) + (a[3] * b[2]) + (a[4] * b[1]);
69         long c6 = (a[2] * b[4]) + (a[3] * b[3]) + (a[4] * b[2]);
70         long c7 = (a[3] * b[4]) + (a[4] * b[3]);
71         long c8 = (a[4] * b[4]);
72 
73         carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
74     }
75 
carryReduce(long[] r, long c0, long c1, long c2, long c3, long c4, long c5, long c6, long c7, long c8)76     private void carryReduce(long[] r, long c0, long c1, long c2, long c3,
77                              long c4, long c5, long c6, long c7, long c8) {
78         //reduce(2, 2)
79         r[2] = c2 + (c7 * SUBTRAHEND);
80         c3 += (c8 * SUBTRAHEND);
81 
82         // carry(3, 2)
83         long carry3 = carryValue(c3);
84         r[3] = c3 - (carry3 << BITS_PER_LIMB);
85         c4 += carry3;
86 
87         long carry4 = carryValue(c4);
88         r[4] = c4 - (carry4 << BITS_PER_LIMB);
89         c5 += carry4;
90 
91         // reduce(0, 2)
92         r[0] = c0 + (c5 * SUBTRAHEND);
93         r[1] = c1 + (c6 * SUBTRAHEND);
94 
95         // carry(0, 4)
96         carry(r);
97     }
98 
99     @Override
square(long[] a, long[] r)100     protected void square(long[] a, long[] r) {
101         // Use grade-school multiplication with a simple squaring optimization.
102         // Multiply into primitives to avoid the temporary array allocation.
103         // This is equivalent to the following code:
104         //  long[] c = new long[2 * NUM_LIMBS - 1];
105         //  for(int i = 0; i < NUM_LIMBS; i++) {
106         //      c[2 * i] = a[i] * a[i];
107         //      for(int j = i + 1; j < NUM_LIMBS; j++) {
108         //          c[i + j] += 2 * a[i] * a[j]
109         //      }
110         //  }
111 
112         long c0 = (a[0] * a[0]);
113         long c1 = 2 * (a[0] * a[1]);
114         long c2 = 2 * (a[0] * a[2]) + (a[1] * a[1]);
115         long c3 = 2 * (a[0] * a[3] + a[1] * a[2]);
116         long c4 = 2 * (a[0] * a[4] + a[1] * a[3]) + (a[2] * a[2]);
117         long c5 = 2 * (a[1] * a[4] + a[2] * a[3]);
118         long c6 = 2 * (a[2] * a[4]) + (a[3] * a[3]);
119         long c7 = 2 * (a[3] * a[4]);
120         long c8 = (a[4] * a[4]);
121 
122         carryReduce(r, c0, c1, c2, c3, c4, c5, c6, c7, c8);
123     }
124 
125     @Override
encode(ByteBuffer buf, int length, byte highByte, long[] result)126     protected void encode(ByteBuffer buf, int length, byte highByte,
127                           long[] result) {
128         if (length == 16) {
129             long low = buf.getLong();
130             long high = buf.getLong();
131             encode(high, low, highByte, result);
132         } else {
133             super.encode(buf, length, highByte, result);
134         }
135     }
136 
encode(long high, long low, byte highByte, long[] result)137     protected void encode(long high, long low, byte highByte, long[] result) {
138         result[0] = low & 0x3FFFFFFL;
139         result[1] = (low >>> 26) & 0x3FFFFFFL;
140         result[2] = (low >>> 52) + ((high & 0x3FFFL) << 12);
141         result[3] = (high >>> 14) & 0x3FFFFFFL;
142         result[4] = (high >>> 40) + (highByte << 24L);
143     }
144 
145     private static final VarHandle AS_LONG_LE = MethodHandles
146         .byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN);
147 
encode(byte[] v, int offset, int length, byte highByte, long[] result)148     protected void encode(byte[] v, int offset, int length, byte highByte,
149                           long[] result) {
150         if (length == 16) {
151             long low = (long) AS_LONG_LE.get(v, offset);
152             long high = (long) AS_LONG_LE.get(v, offset + 8);
153             encode(high, low, highByte, result);
154         } else {
155             super.encode(v, offset, length, highByte, result);
156         }
157     }
158 
159     @Override
reduceIn(long[] limbs, long x, int index)160     protected void reduceIn(long[] limbs, long x, int index) {
161         // this only works when BITS_PER_LIMB * NUM_LIMBS = POWER exactly
162         long reducedValue = (x * SUBTRAHEND);
163         limbs[index - NUM_LIMBS] += reducedValue;
164     }
165 
166     @Override
finalCarryReduceLast(long[] limbs)167     protected void finalCarryReduceLast(long[] limbs) {
168         long carry = limbs[numLimbs - 1] >> bitsPerLimb;
169         limbs[numLimbs - 1] -= carry << bitsPerLimb;
170         reduceIn(limbs, carry, numLimbs);
171     }
172 
modReduce(long[] limbs, int start, int end)173     protected final void modReduce(long[] limbs, int start, int end) {
174 
175         for (int i = start; i < end; i++) {
176             reduceIn(limbs, limbs[i], i);
177             limbs[i] = 0;
178         }
179     }
180 
modReduce(long[] limbs)181     protected void modReduce(long[] limbs) {
182 
183         modReduce(limbs, NUM_LIMBS, NUM_LIMBS - 1);
184     }
185 
186     @Override
carryValue(long x)187     protected long carryValue(long x) {
188         // This representation has plenty of extra space, so we can afford to
189         // do a simplified carry operation that is more time-efficient.
190 
191         return x >> BITS_PER_LIMB;
192     }
193 
194     @Override
postEncodeCarry(long[] v)195     protected void postEncodeCarry(long[] v) {
196         // not needed because carry is unsigned
197     }
198 
199     @Override
reduce(long[] limbs)200     protected void reduce(long[] limbs) {
201         long carry3 = carryOut(limbs, 3);
202         long new4 = carry3 + limbs[4];
203 
204         long carry4 = carryValue(new4);
205         limbs[4] = new4 - (carry4 << BITS_PER_LIMB);
206 
207         reduceIn(limbs, carry4, 5);
208         carry(limbs);
209     }
210 
211 }
212 
213