1 /*
2  * Copyright (c) 2002, 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.  Oracle designates this
8  * particular file as subject to the "Classpath" exception as provided
9  * by Oracle in the LICENSE file that accompanied this code.
10  *
11  * This code is distributed in the hope that it will be useful, but WITHOUT
12  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
14  * version 2 for more details (a copy is included in the LICENSE file that
15  * accompanied this code).
16  *
17  * You should have received a copy of the GNU General Public License version
18  * 2 along with this work; if not, write to the Free Software Foundation,
19  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
20  *
21  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
22  * or visit www.oracle.com if you need additional information or have any
23  * questions.
24  */
25 
26 /* $Id: Rijndael.java,v 1.6 2000/02/10 01:31:41 gelderen Exp $
27  *
28  * Copyright (C) 1995-2000 The Cryptix Foundation Limited.
29  * All rights reserved.
30  *
31  * Use, modification, copying and distribution of this softwareas is subject
32  * the terms and conditions of the Cryptix General Licence. You should have
33  * received a copy of the Cryptix General Licence along with this library;
34  * if not, you can download a copy from http://www.cryptix.org/ .
35  */
36 
37 package com.sun.crypto.provider;
38 
39 import java.security.InvalidKeyException;
40 import java.security.MessageDigest;
41 import java.util.Arrays;
42 
43 import jdk.internal.vm.annotation.IntrinsicCandidate;
44 
45 /**
46  * Rijndael --pronounced Reindaal-- is a symmetric cipher with a 128-bit
47  * block size and variable key-size (128-, 192- and 256-bit).
48  * <p>
49  * Rijndael was designed by <a href="mailto:rijmen@esat.kuleuven.ac.be">Vincent
50  * Rijmen</a> and <a href="mailto:Joan.Daemen@village.uunet.be">Joan Daemen</a>.
51  */
52 final class AESCrypt extends SymmetricCipher implements AESConstants
53 {
54     private boolean ROUNDS_12 = false;
55     private boolean ROUNDS_14 = false;
56 
57     /** Session and Sub keys */
58     private int[][] sessionK = null;
59     private int[] K = null;
60 
61     /** Cipher encryption/decryption key */
62     // skip re-generating Session and Sub keys if the cipher key is
63     // the same
64     private byte[] lastKey = null;
65 
66     /** ROUNDS * 4 */
67     private int limit = 0;
68 
AESCrypt()69     AESCrypt() {
70         // empty
71     }
72 
73     /**
74      * Returns this cipher's block size.
75      *
76      * @return this cipher's block size
77      */
getBlockSize()78     int getBlockSize() {
79         return AES_BLOCK_SIZE;
80     }
81 
init(boolean decrypting, String algorithm, byte[] key)82     void init(boolean decrypting, String algorithm, byte[] key)
83             throws InvalidKeyException {
84         if (!algorithm.equalsIgnoreCase("AES")
85                     && !algorithm.equalsIgnoreCase("Rijndael")) {
86             throw new InvalidKeyException
87                 ("Wrong algorithm: AES or Rijndael required");
88         }
89         if (!isKeySizeValid(key.length)) {
90             throw new InvalidKeyException("Invalid AES key length: " +
91                 key.length + " bytes");
92         }
93 
94         if (!MessageDigest.isEqual(key, lastKey)) {
95             // re-generate session key 'sessionK' when cipher key changes
96             makeSessionKey(key);
97             lastKey = key.clone();  // save cipher key
98         }
99 
100         // set sub key to the corresponding session Key
101         this.K = sessionK[(decrypting? 1:0)];
102     }
103 
104     /**
105      * Expand an int[(ROUNDS+1)][4] into int[(ROUNDS+1)*4].
106      * For decryption round keys, need to rotate right by 4 ints.
107      * @param kr The round keys for encryption or decryption.
108      * @param decrypting True if 'kr' is for decryption and false otherwise.
109      */
expandToSubKey(int[][] kr, boolean decrypting)110     private static final int[] expandToSubKey(int[][] kr, boolean decrypting) {
111         int total = kr.length;
112         int[] expK = new int[total*4];
113         if (decrypting) {
114             // decrypting, rotate right by 4 ints
115             // i.e. i==0
116             for(int j=0; j<4; j++) {
117                 expK[j] = kr[total-1][j];
118             }
119             for(int i=1; i<total; i++) {
120                 for(int j=0; j<4; j++) {
121                     expK[i*4 + j] = kr[i-1][j];
122                 }
123             }
124         } else {
125             // encrypting, straight expansion
126             for(int i=0; i<total; i++) {
127                 for(int j=0; j<4; j++) {
128                     expK[i*4 + j] = kr[i][j];
129                 }
130             }
131         }
132         return expK;
133     }
134 
135     private static int[]
136         alog = new int[256],
137         log  = new int[256];
138 
139     private static final byte[]
140         S  = new byte[256],
141         Si = new byte[256];
142 
143     private static final int[]
144         T1 = new int[256],
145         T2 = new int[256],
146         T3 = new int[256],
147         T4 = new int[256],
148         T5 = new int[256],
149         T6 = new int[256],
150         T7 = new int[256],
151         T8 = new int[256];
152 
153     private static final int[]
154         U1 = new int[256],
155         U2 = new int[256],
156         U3 = new int[256],
157         U4 = new int[256];
158 
159     private static final byte[] rcon = new byte[30];
160 
161 
162     // Static code - to intialise S-boxes and T-boxes
163     static
164     {
165         int ROOT = 0x11B;
166         int i, j = 0;
167 
168         //
169         // produce log and alog tables, needed for multiplying in the
170         // field GF(2^m) (generator = 3)
171         //
172         alog[0] = 1;
173         for (i = 1; i < 256; i++)
174         {
175             j = (alog[i-1] << 1) ^ alog[i-1];
176             if ((j & 0x100) != 0) {
177                 j ^= ROOT;
178             }
179             alog[i] = j;
180         }
181         for (i = 1; i < 255; i++) {
182             log[alog[i]] = i;
183         }
184         byte[][] A = new byte[][]
185         {
186             {1, 1, 1, 1, 1, 0, 0, 0},
187             {0, 1, 1, 1, 1, 1, 0, 0},
188             {0, 0, 1, 1, 1, 1, 1, 0},
189             {0, 0, 0, 1, 1, 1, 1, 1},
190             {1, 0, 0, 0, 1, 1, 1, 1},
191             {1, 1, 0, 0, 0, 1, 1, 1},
192             {1, 1, 1, 0, 0, 0, 1, 1},
193             {1, 1, 1, 1, 0, 0, 0, 1}
194         };
195         byte[] B = new byte[] { 0, 1, 1, 0, 0, 0, 1, 1};
196 
197         //
198         // substitution box based on F^{-1}(x)
199         //
200         int t;
201         byte[][] box = new byte[256][8];
202         box[1][7] = 1;
203         for (i = 2; i < 256; i++) {
204             j = alog[255 - log[i]];
205             for (t = 0; t < 8; t++) {
206                 box[i][t] = (byte)((j >>> (7 - t)) & 0x01);
207             }
208         }
209         //
210         // affine transform:  box[i] <- B + A*box[i]
211         //
212         byte[][] cox = new byte[256][8];
213         for (i = 0; i < 256; i++) {
214             for (t = 0; t < 8; t++) {
215                 cox[i][t] = B[t];
216                 for (j = 0; j < 8; j++) {
217                     cox[i][t] ^= A[t][j] * box[i][j];
218                 }
219             }
220         }
221         //
222         // S-boxes and inverse S-boxes
223         //
224         for (i = 0; i < 256; i++) {
225             S[i] = (byte)(cox[i][0] << 7);
226             for (t = 1; t < 8; t++) {
227                     S[i] ^= cox[i][t] << (7-t);
228             }
229             Si[S[i] & 0xFF] = (byte) i;
230         }
231         //
232         // T-boxes
233         //
234         byte[][] G = new byte[][] {
235             {2, 1, 1, 3},
236             {3, 2, 1, 1},
237             {1, 3, 2, 1},
238             {1, 1, 3, 2}
239         };
240         byte[][] AA = new byte[4][8];
241         for (i = 0; i < 4; i++) {
242             for (j = 0; j < 4; j++) AA[i][j] = G[i][j];
243             AA[i][i+4] = 1;
244         }
245         byte pivot, tmp;
246         byte[][] iG = new byte[4][4];
247         for (i = 0; i < 4; i++) {
248             pivot = AA[i][i];
249             if (pivot == 0) {
250                 t = i + 1;
251                 while ((AA[t][i] == 0) && (t < 4)) {
252                     t++;
253                 }
254                 if (t == 4) {
255                     throw new RuntimeException("G matrix is not invertible");
256                 }
257                 else {
258                     for (j = 0; j < 8; j++) {
259                         tmp = AA[i][j];
260                         AA[i][j] = AA[t][j];
261                         AA[t][j] = tmp;
262                     }
263                     pivot = AA[i][i];
264                 }
265             }
266             for (j = 0; j < 8; j++) {
267                 if (AA[i][j] != 0) {
268                     AA[i][j] = (byte)
269                         alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF])
270                         % 255];
271                 }
272             }
273             for (t = 0; t < 4; t++) {
274                 if (i != t) {
275                     for (j = i+1; j < 8; j++) {
276                         AA[t][j] ^= mul(AA[i][j], AA[t][i]);
277                     }
278                     AA[t][i] = 0;
279                 }
280             }
281         }
282         for (i = 0; i < 4; i++) {
283             for (j = 0; j < 4; j++) {
284                 iG[i][j] = AA[i][j + 4];
285             }
286         }
287 
288         int s;
289         for (t = 0; t < 256; t++) {
290             s = S[t];
291             T1[t] = mul4(s, G[0]);
292             T2[t] = mul4(s, G[1]);
293             T3[t] = mul4(s, G[2]);
294             T4[t] = mul4(s, G[3]);
295 
296             s = Si[t];
297             T5[t] = mul4(s, iG[0]);
298             T6[t] = mul4(s, iG[1]);
299             T7[t] = mul4(s, iG[2]);
300             T8[t] = mul4(s, iG[3]);
301 
302             U1[t] = mul4(t, iG[0]);
303             U2[t] = mul4(t, iG[1]);
304             U3[t] = mul4(t, iG[2]);
305             U4[t] = mul4(t, iG[3]);
306         }
307         //
308         // round constants
309         //
310         rcon[0] = 1;
311         int r = 1;
312         for (t = 1; t < 30; t++) {
313             r = mul(2, r);
314             rcon[t] = (byte) r;
315         }
316         log = null;
317         alog = null;
318     }
319 
320     // multiply two elements of GF(2^m)
mul(int a, int b)321     private static final int mul (int a, int b) {
322         return (a != 0 && b != 0) ?
323             alog[(log[a & 0xFF] + log[b & 0xFF]) % 255] :
324             0;
325     }
326 
327     // convenience method used in generating Transposition boxes
mul4(int a, byte[] b)328     private static final int mul4 (int a, byte[] b) {
329         if (a == 0) return 0;
330         a = log[a & 0xFF];
331         int a0 = (b[0] != 0) ? alog[(a + log[b[0] & 0xFF]) % 255] & 0xFF : 0;
332         int a1 = (b[1] != 0) ? alog[(a + log[b[1] & 0xFF]) % 255] & 0xFF : 0;
333         int a2 = (b[2] != 0) ? alog[(a + log[b[2] & 0xFF]) % 255] & 0xFF : 0;
334         int a3 = (b[3] != 0) ? alog[(a + log[b[3] & 0xFF]) % 255] & 0xFF : 0;
335         return a0 << 24 | a1 << 16 | a2 << 8 | a3;
336     }
337 
338     // check if the specified length (in bytes) is a valid keysize for AES
isKeySizeValid(int len)339     static final boolean isKeySizeValid(int len) {
340         for (int i = 0; i < AES_KEYSIZES.length; i++) {
341             if (len == AES_KEYSIZES[i]) {
342                 return true;
343             }
344         }
345         return false;
346     }
347 
348     /**
349      * Encrypt exactly one block of plaintext.
350      */
encryptBlock(byte[] in, int inOffset, byte[] out, int outOffset)351     void encryptBlock(byte[] in, int inOffset,
352                       byte[] out, int outOffset) {
353         // Array bound checks are done in caller code, i.e.
354         // FeedbackCipher.encrypt/decrypt(...) to improve performance.
355         implEncryptBlock(in, inOffset, out, outOffset);
356     }
357 
    // Encryption operation. Possibly replaced with a compiler intrinsic.
    //
    // Encrypts the 16 bytes starting at in[inOffset] into out[outOffset..]
    // using the round keys in K, as selected by init().  No bounds checks are
    // done here; see encryptBlock.
    @IntrinsicCandidate
    private void implEncryptBlock(byte[] in, int inOffset,
                                  byte[] out, int outOffset)
    {
        int keyOffset = 0;
        // Load the block as four big-endian words and XOR in the first
        // round key (K[0..3]).
        int t0   = ((in[inOffset++]       ) << 24 |
                    (in[inOffset++] & 0xFF) << 16 |
                    (in[inOffset++] & 0xFF) <<  8 |
                    (in[inOffset++] & 0xFF)        ) ^ K[keyOffset++];
        int t1   = ((in[inOffset++]       ) << 24 |
                    (in[inOffset++] & 0xFF) << 16 |
                    (in[inOffset++] & 0xFF) <<  8 |
                    (in[inOffset++] & 0xFF)        ) ^ K[keyOffset++];
        int t2   = ((in[inOffset++]       ) << 24 |
                    (in[inOffset++] & 0xFF) << 16 |
                    (in[inOffset++] & 0xFF) <<  8 |
                    (in[inOffset++] & 0xFF)        ) ^ K[keyOffset++];
        int t3   = ((in[inOffset++]       ) << 24 |
                    (in[inOffset++] & 0xFF) << 16 |
                    (in[inOffset++] & 0xFF) <<  8 |
                    (in[inOffset++] & 0xFF)        ) ^ K[keyOffset++];

        // apply round transforms
        // Each iteration is one full round done with T1..T4 lookups and four
        // round-key words; it stops when keyOffset reaches limit (ROUNDS*4).
        // t3 is updated in place because a0..a2 already hold the old values
        // it depends on.
        while( keyOffset < limit )
        {
            int a0, a1, a2;
            a0 = T1[(t0 >>> 24)       ] ^
                 T2[(t1 >>> 16) & 0xFF] ^
                 T3[(t2 >>>  8) & 0xFF] ^
                 T4[(t3       ) & 0xFF] ^ K[keyOffset++];
            a1 = T1[(t1 >>> 24)       ] ^
                 T2[(t2 >>> 16) & 0xFF] ^
                 T3[(t3 >>>  8) & 0xFF] ^
                 T4[(t0       ) & 0xFF] ^ K[keyOffset++];
            a2 = T1[(t2 >>> 24)       ] ^
                 T2[(t3 >>> 16) & 0xFF] ^
                 T3[(t0 >>>  8) & 0xFF] ^
                 T4[(t1       ) & 0xFF] ^ K[keyOffset++];
            t3 = T1[(t3 >>> 24)       ] ^
                 T2[(t0 >>> 16) & 0xFF] ^
                 T3[(t1 >>>  8) & 0xFF] ^
                 T4[(t2       ) & 0xFF] ^ K[keyOffset++];
            t0 = a0; t1 = a1; t2 = a2;
        }

        // last round is special
        // It uses the plain S-box (no T-box mixing) and writes the output
        // bytes directly, XORed with the final four round-key words.
        int tt = K[keyOffset++];
        out[outOffset++] = (byte)(S[(t0 >>> 24)       ] ^ (tt >>> 24));
        out[outOffset++] = (byte)(S[(t1 >>> 16) & 0xFF] ^ (tt >>> 16));
        out[outOffset++] = (byte)(S[(t2 >>>  8) & 0xFF] ^ (tt >>>  8));
        out[outOffset++] = (byte)(S[(t3       ) & 0xFF] ^ (tt       ));
        tt = K[keyOffset++];
        out[outOffset++] = (byte)(S[(t1 >>> 24)       ] ^ (tt >>> 24));
        out[outOffset++] = (byte)(S[(t2 >>> 16) & 0xFF] ^ (tt >>> 16));
        out[outOffset++] = (byte)(S[(t3 >>>  8) & 0xFF] ^ (tt >>>  8));
        out[outOffset++] = (byte)(S[(t0       ) & 0xFF] ^ (tt       ));
        tt = K[keyOffset++];
        out[outOffset++] = (byte)(S[(t2 >>> 24)       ] ^ (tt >>> 24));
        out[outOffset++] = (byte)(S[(t3 >>> 16) & 0xFF] ^ (tt >>> 16));
        out[outOffset++] = (byte)(S[(t0 >>>  8) & 0xFF] ^ (tt >>>  8));
        out[outOffset++] = (byte)(S[(t1       ) & 0xFF] ^ (tt       ));
        tt = K[keyOffset++];
        out[outOffset++] = (byte)(S[(t3 >>> 24)       ] ^ (tt >>> 24));
        out[outOffset++] = (byte)(S[(t0 >>> 16) & 0xFF] ^ (tt >>> 16));
        out[outOffset++] = (byte)(S[(t1 >>>  8) & 0xFF] ^ (tt >>>  8));
        out[outOffset  ] = (byte)(S[(t2       ) & 0xFF] ^ (tt       ));
    }
426 
427     /**
428      * Decrypt exactly one block of plaintext.
429      */
decryptBlock(byte[] in, int inOffset, byte[] out, int outOffset)430     void decryptBlock(byte[] in, int inOffset,
431                       byte[] out, int outOffset) {
432         // Array bound checks are done in caller code, i.e.
433         // FeedbackCipher.encrypt/decrypt(...) to improve performance.
434         implDecryptBlock(in, inOffset, out, outOffset);
435     }
436 
    // Decrypt operation. Possibly replaced with a compiler intrinsic.
    //
    // Decrypts the 16 bytes starting at in[inOffset] into out[outOffset..]
    // using the round keys in K, as selected by init().  The decryption
    // schedule is rotated by expandToSubKey, so K[4..7] is applied first and
    // K[0..3] last.  No bounds checks are done here; see decryptBlock.
    @IntrinsicCandidate
    private void implDecryptBlock(byte[] in, int inOffset,
                                  byte[] out, int outOffset)
    {
        int keyOffset = 4;
        // Load the block as four big-endian words and XOR in K[4..7].
        int t0 = ((in[inOffset++]       ) << 24 |
                  (in[inOffset++] & 0xFF) << 16 |
                  (in[inOffset++] & 0xFF) <<  8 |
                  (in[inOffset++] & 0xFF)        ) ^ K[keyOffset++];
        int t1 = ((in[inOffset++]       ) << 24 |
                  (in[inOffset++] & 0xFF) << 16 |
                  (in[inOffset++] & 0xFF) <<  8 |
                  (in[inOffset++] & 0xFF)        ) ^ K[keyOffset++];
        int t2 = ((in[inOffset++]       ) << 24 |
                  (in[inOffset++] & 0xFF) << 16 |
                  (in[inOffset++] & 0xFF) <<  8 |
                  (in[inOffset++] & 0xFF)        ) ^ K[keyOffset++];
        int t3 = ((in[inOffset++]       ) << 24 |
                  (in[inOffset++] & 0xFF) << 16 |
                  (in[inOffset++] & 0xFF) <<  8 |
                  (in[inOffset  ] & 0xFF)        ) ^ K[keyOffset++];

        // Rounds are fully unrolled.  Each group of four assignments is one
        // round (T5..T8 lookups plus four round-key words); state alternates
        // between (t0,t1,t2,t3) and (a0,a1,a2,t3).
        int a0, a1, a2;
        // Two extra rounds when the schedule has at least 12 rounds, and two
        // more when it has 14 (see makeSessionKey).
        if(ROUNDS_12)
        {
            a0 = T5[(t0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
                 T7[(t2>>> 8)&0xFF] ^ T8[(t1     )&0xFF] ^ K[keyOffset++];
            a1 = T5[(t1>>>24)     ] ^ T6[(t0>>>16)&0xFF] ^
                 T7[(t3>>> 8)&0xFF] ^ T8[(t2     )&0xFF] ^ K[keyOffset++];
            a2 = T5[(t2>>>24)     ] ^ T6[(t1>>>16)&0xFF] ^
                 T7[(t0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
            t3 = T5[(t3>>>24)     ] ^ T6[(t2>>>16)&0xFF] ^
                 T7[(t1>>> 8)&0xFF] ^ T8[(t0     )&0xFF] ^ K[keyOffset++];
            t0 = T5[(a0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
                 T7[(a2>>> 8)&0xFF] ^ T8[(a1     )&0xFF] ^ K[keyOffset++];
            t1 = T5[(a1>>>24)     ] ^ T6[(a0>>>16)&0xFF] ^
                 T7[(t3>>> 8)&0xFF] ^ T8[(a2     )&0xFF] ^ K[keyOffset++];
            t2 = T5[(a2>>>24)     ] ^ T6[(a1>>>16)&0xFF] ^
                 T7[(a0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
            t3 = T5[(t3>>>24)     ] ^ T6[(a2>>>16)&0xFF] ^
                 T7[(a1>>> 8)&0xFF] ^ T8[(a0     )&0xFF] ^ K[keyOffset++];

            if(ROUNDS_14)
            {
                a0 = T5[(t0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
                     T7[(t2>>> 8)&0xFF] ^ T8[(t1     )&0xFF] ^ K[keyOffset++];
                a1 = T5[(t1>>>24)     ] ^ T6[(t0>>>16)&0xFF] ^
                     T7[(t3>>> 8)&0xFF] ^ T8[(t2     )&0xFF] ^ K[keyOffset++];
                a2 = T5[(t2>>>24)     ] ^ T6[(t1>>>16)&0xFF] ^
                     T7[(t0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
                t3 = T5[(t3>>>24)     ] ^ T6[(t2>>>16)&0xFF] ^
                     T7[(t1>>> 8)&0xFF] ^ T8[(t0     )&0xFF] ^ K[keyOffset++];
                t0 = T5[(a0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
                     T7[(a2>>> 8)&0xFF] ^ T8[(a1     )&0xFF] ^ K[keyOffset++];
                t1 = T5[(a1>>>24)     ] ^ T6[(a0>>>16)&0xFF] ^
                     T7[(t3>>> 8)&0xFF] ^ T8[(a2     )&0xFF] ^ K[keyOffset++];
                t2 = T5[(a2>>>24)     ] ^ T6[(a1>>>16)&0xFF] ^
                     T7[(a0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
                t3 = T5[(t3>>>24)     ] ^ T6[(a2>>>16)&0xFF] ^
                     T7[(a1>>> 8)&0xFF] ^ T8[(a0     )&0xFF] ^ K[keyOffset++];
            }
        }
        // The nine full rounds performed for every key size.
        a0 = T5[(t0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(t2>>> 8)&0xFF] ^ T8[(t1     )&0xFF] ^ K[keyOffset++];
        a1 = T5[(t1>>>24)     ] ^ T6[(t0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(t2     )&0xFF] ^ K[keyOffset++];
        a2 = T5[(t2>>>24)     ] ^ T6[(t1>>>16)&0xFF] ^
             T7[(t0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(t2>>>16)&0xFF] ^
             T7[(t1>>> 8)&0xFF] ^ T8[(t0     )&0xFF] ^ K[keyOffset++];
        t0 = T5[(a0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(a2>>> 8)&0xFF] ^ T8[(a1     )&0xFF] ^ K[keyOffset++];
        t1 = T5[(a1>>>24)     ] ^ T6[(a0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(a2     )&0xFF] ^ K[keyOffset++];
        t2 = T5[(a2>>>24)     ] ^ T6[(a1>>>16)&0xFF] ^
             T7[(a0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(a2>>>16)&0xFF] ^
             T7[(a1>>> 8)&0xFF] ^ T8[(a0     )&0xFF] ^ K[keyOffset++];
        a0 = T5[(t0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(t2>>> 8)&0xFF] ^ T8[(t1     )&0xFF] ^ K[keyOffset++];
        a1 = T5[(t1>>>24)     ] ^ T6[(t0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(t2     )&0xFF] ^ K[keyOffset++];
        a2 = T5[(t2>>>24)     ] ^ T6[(t1>>>16)&0xFF] ^
             T7[(t0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(t2>>>16)&0xFF] ^
             T7[(t1>>> 8)&0xFF] ^ T8[(t0     )&0xFF] ^ K[keyOffset++];
        t0 = T5[(a0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(a2>>> 8)&0xFF] ^ T8[(a1     )&0xFF] ^ K[keyOffset++];
        t1 = T5[(a1>>>24)     ] ^ T6[(a0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(a2     )&0xFF] ^ K[keyOffset++];
        t2 = T5[(a2>>>24)     ] ^ T6[(a1>>>16)&0xFF] ^
             T7[(a0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(a2>>>16)&0xFF] ^
             T7[(a1>>> 8)&0xFF] ^ T8[(a0     )&0xFF] ^ K[keyOffset++];
        a0 = T5[(t0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(t2>>> 8)&0xFF] ^ T8[(t1     )&0xFF] ^ K[keyOffset++];
        a1 = T5[(t1>>>24)     ] ^ T6[(t0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(t2     )&0xFF] ^ K[keyOffset++];
        a2 = T5[(t2>>>24)     ] ^ T6[(t1>>>16)&0xFF] ^
             T7[(t0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(t2>>>16)&0xFF] ^
             T7[(t1>>> 8)&0xFF] ^ T8[(t0     )&0xFF] ^ K[keyOffset++];
        t0 = T5[(a0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(a2>>> 8)&0xFF] ^ T8[(a1     )&0xFF] ^ K[keyOffset++];
        t1 = T5[(a1>>>24)     ] ^ T6[(a0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(a2     )&0xFF] ^ K[keyOffset++];
        t2 = T5[(a2>>>24)     ] ^ T6[(a1>>>16)&0xFF] ^
             T7[(a0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(a2>>>16)&0xFF] ^
             T7[(a1>>> 8)&0xFF] ^ T8[(a0     )&0xFF] ^ K[keyOffset++];
        a0 = T5[(t0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(t2>>> 8)&0xFF] ^ T8[(t1     )&0xFF] ^ K[keyOffset++];
        a1 = T5[(t1>>>24)     ] ^ T6[(t0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(t2     )&0xFF] ^ K[keyOffset++];
        a2 = T5[(t2>>>24)     ] ^ T6[(t1>>>16)&0xFF] ^
             T7[(t0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(t2>>>16)&0xFF] ^
             T7[(t1>>> 8)&0xFF] ^ T8[(t0     )&0xFF] ^ K[keyOffset++];
        t0 = T5[(a0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(a2>>> 8)&0xFF] ^ T8[(a1     )&0xFF] ^ K[keyOffset++];
        t1 = T5[(a1>>>24)     ] ^ T6[(a0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(a2     )&0xFF] ^ K[keyOffset++];
        t2 = T5[(a2>>>24)     ] ^ T6[(a1>>>16)&0xFF] ^
             T7[(a0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(a2>>>16)&0xFF] ^
             T7[(a1>>> 8)&0xFF] ^ T8[(a0     )&0xFF] ^ K[keyOffset++];
        a0 = T5[(t0>>>24)     ] ^ T6[(t3>>>16)&0xFF] ^
             T7[(t2>>> 8)&0xFF] ^ T8[(t1     )&0xFF] ^ K[keyOffset++];
        a1 = T5[(t1>>>24)     ] ^ T6[(t0>>>16)&0xFF] ^
             T7[(t3>>> 8)&0xFF] ^ T8[(t2     )&0xFF] ^ K[keyOffset++];
        a2 = T5[(t2>>>24)     ] ^ T6[(t1>>>16)&0xFF] ^
             T7[(t0>>> 8)&0xFF] ^ T8[(t3     )&0xFF] ^ K[keyOffset++];
        t3 = T5[(t3>>>24)     ] ^ T6[(t2>>>16)&0xFF] ^
             T7[(t1>>> 8)&0xFF] ^ T8[(t0     )&0xFF] ^ K[keyOffset++];

        // Final round: inverse S-box only, XORed with K[0..3].  The column
        // order (a0, t3, a2, a1) undoes the row shift.  t1 is reused as a
        // scratch register for the round-key word.
        t1 = K[0];
        out[outOffset++] = (byte)(Si[(a0 >>> 24)       ] ^ (t1 >>> 24));
        out[outOffset++] = (byte)(Si[(t3 >>> 16) & 0xFF] ^ (t1 >>> 16));
        out[outOffset++] = (byte)(Si[(a2 >>>  8) & 0xFF] ^ (t1 >>>  8));
        out[outOffset++] = (byte)(Si[(a1       ) & 0xFF] ^ (t1       ));
        t1 = K[1];
        out[outOffset++] = (byte)(Si[(a1 >>> 24)       ] ^ (t1 >>> 24));
        out[outOffset++] = (byte)(Si[(a0 >>> 16) & 0xFF] ^ (t1 >>> 16));
        out[outOffset++] = (byte)(Si[(t3 >>>  8) & 0xFF] ^ (t1 >>>  8));
        out[outOffset++] = (byte)(Si[(a2       ) & 0xFF] ^ (t1       ));
        t1 = K[2];
        out[outOffset++] = (byte)(Si[(a2 >>> 24)       ] ^ (t1 >>> 24));
        out[outOffset++] = (byte)(Si[(a1 >>> 16) & 0xFF] ^ (t1 >>> 16));
        out[outOffset++] = (byte)(Si[(a0 >>>  8) & 0xFF] ^ (t1 >>>  8));
        out[outOffset++] = (byte)(Si[(t3       ) & 0xFF] ^ (t1       ));
        t1 = K[3];
        out[outOffset++] = (byte)(Si[(t3 >>> 24)       ] ^ (t1 >>> 24));
        out[outOffset++] = (byte)(Si[(a2 >>> 16) & 0xFF] ^ (t1 >>> 16));
        out[outOffset++] = (byte)(Si[(a1 >>>  8) & 0xFF] ^ (t1 >>>  8));
        out[outOffset  ] = (byte)(Si[(a0       ) & 0xFF] ^ (t1       ));
    }
594 
595     /**
596      * Expand a user-supplied key material into a session key.
597      *
598      * @param k The 128/192/256-bit cipher key to use.
599      * @exception InvalidKeyException  If the key is invalid.
600      */
makeSessionKey(byte[] k)601     private void makeSessionKey(byte[] k) throws InvalidKeyException {
602         if (k == null) {
603             throw new InvalidKeyException("Empty key");
604         }
605         if (!isKeySizeValid(k.length)) {
606              throw new InvalidKeyException("Invalid AES key length: " +
607                                            k.length + " bytes");
608         }
609         int ROUNDS          = getRounds(k.length);
610         int ROUND_KEY_COUNT = (ROUNDS + 1) * 4;
611 
612         int BC = 4;
613         int[][] Ke = new int[ROUNDS + 1][4]; // encryption round keys
614         int[][] Kd = new int[ROUNDS + 1][4]; // decryption round keys
615 
616         int KC = k.length/4; // keylen in 32-bit elements
617 
618         int[] tk = new int[KC];
619         int i, j;
620 
621         // copy user material bytes into temporary ints
622         for (i = 0, j = 0; i < KC; i++, j+=4) {
623             tk[i] = (k[j]       ) << 24 |
624                     (k[j+1] & 0xFF) << 16 |
625                     (k[j+2] & 0xFF) <<  8 |
626                     (k[j+3] & 0xFF);
627         }
628 
629         // copy values into round key arrays
630         int t = 0;
631         for (j = 0; (j < KC) && (t < ROUND_KEY_COUNT); j++, t++) {
632             Ke[t / 4][t % 4] = tk[j];
633             Kd[ROUNDS - (t / 4)][t % 4] = tk[j];
634         }
635         int tt, rconpointer = 0;
636         while (t < ROUND_KEY_COUNT) {
637             // extrapolate using phi (the round key evolution function)
638             tt = tk[KC - 1];
639             tk[0] ^= (S[(tt >>> 16) & 0xFF]       ) << 24 ^
640                      (S[(tt >>>  8) & 0xFF] & 0xFF) << 16 ^
641                      (S[(tt       ) & 0xFF] & 0xFF) <<  8 ^
642                      (S[(tt >>> 24)       ] & 0xFF)       ^
643                      (rcon[rconpointer++]         ) << 24;
644             if (KC != 8)
645                 for (i = 1, j = 0; i < KC; i++, j++) tk[i] ^= tk[j];
646             else {
647                 for (i = 1, j = 0; i < KC / 2; i++, j++) tk[i] ^= tk[j];
648                 tt = tk[KC / 2 - 1];
649                 tk[KC / 2] ^= (S[(tt       ) & 0xFF] & 0xFF)       ^
650                               (S[(tt >>>  8) & 0xFF] & 0xFF) <<  8 ^
651                               (S[(tt >>> 16) & 0xFF] & 0xFF) << 16 ^
652                               (S[(tt >>> 24)       ]       ) << 24;
653                 for (j = KC / 2, i = j + 1; i < KC; i++, j++) tk[i] ^= tk[j];
654             }
655             // copy values into round key arrays
656             for (j = 0; (j < KC) && (t < ROUND_KEY_COUNT); j++, t++) {
657                 Ke[t / 4][t % 4] = tk[j];
658                 Kd[ROUNDS - (t / 4)][t % 4] = tk[j];
659             }
660         }
661         for (int r = 1; r < ROUNDS; r++) {
662             // inverse MixColumn where needed
663             for (j = 0; j < BC; j++) {
664                 tt = Kd[r][j];
665                 Kd[r][j] = U1[(tt >>> 24) & 0xFF] ^
666                            U2[(tt >>> 16) & 0xFF] ^
667                            U3[(tt >>>  8) & 0xFF] ^
668                            U4[ tt         & 0xFF];
669             }
670         }
671 
672         // assemble the encryption (Ke) and decryption (Kd) round keys
673         // and expand them into arrays of ints.
674         int[] expandedKe = expandToSubKey(Ke, false); // decrypting==false
675         int[] expandedKd = expandToSubKey(Kd, true);  // decrypting==true
676 
677         ROUNDS_12 = (ROUNDS>=12);
678         ROUNDS_14 = (ROUNDS==14);
679         limit = ROUNDS*4;
680 
681         // store the expanded sub keys into 'sessionK'
682         sessionK = new int[][] { expandedKe, expandedKd };
683     }
684 
685 
686     /**
687      * Return The number of rounds for a given Rijndael keysize.
688      *
689      * @param keySize  The size of the user key material in bytes.
690      *                 MUST be one of (16, 24, 32).
691      * @return         The number of rounds.
692      */
getRounds(int keySize)693     private static int getRounds(int keySize) {
694         return (keySize >> 2) + 6;
695     }
696 }
697