1 /*
2  * Copyright (c) 2001, Dr Brian Gladman <brg@gladman.uk.net>, Worcester, UK.
3  * All rights reserved.
4  *
5  * LICENSE TERMS
6  *
7  * The free distribution and use of this software in both source and binary
8  * form is allowed (with or without changes) provided that:
9  *
10  *   1. distributions of this source code include the above copyright
11  *      notice, this list of conditions and the following disclaimer;
12  *
13  *   2. distributions in binary form include the above copyright
14  *      notice, this list of conditions and the following disclaimer
15  *      in the documentation and/or other associated materials;
16  *
17  *   3. the copyright holder's name is not used to endorse products
18  *      built using this software without specific written permission.
19  *
20  * DISCLAIMER
21  *
22  * This software is provided 'as is' with no explcit or implied warranties
23  * in respect of any properties, including, but not limited to, correctness
24  * and fitness for purpose.
25  */
26 
27 /*
28  * Issue Date: 21/01/2002
29  *
30  * This file contains the code for implementing the key schedule for AES
31  * (Rijndael) for block and key sizes of 16, 24, and 32 bytes.
32  */
33 
34 #include "aesopt.h"
35 
/* Compile-time guard for a block size fixed at build time: only 16, 24 and
   32 bytes are supported, the same set accepted at run time by aes_blk_len().
   The previous test (BLOCK_SIZE & 7) alone let 8, 40, 48, ... through. */
#if defined(BLOCK_SIZE) && (((BLOCK_SIZE) & 7) || (BLOCK_SIZE) < 16 || (BLOCK_SIZE) > 32)
#error An illegal block size has been specified.
#endif
39 
/* Choose the cipher block length at run time.

   Only compiled when the block size is not fixed at build time (BLOCK_SIZE
   undefined) and SET_BLOCK_LENGTH is defined.

   blen - requested block length in bytes; 16, 24 and 32 are accepted.
   cx   - context whose n_blk field receives the block length.

   Returns aes_good after storing blen in cx->n_blk. For any other length
   cx->n_blk is set to 0 and aes_bad is returned; a zero n_blk is later
   replaced by the default of 16 when a key is set.
*/

#if !defined(BLOCK_SIZE) && defined(SET_BLOCK_LENGTH)

aes_rval aes_blk_len(unsigned int blen, aes_ctx cx[1])
{
#if !defined(FIXED_TABLES)
    /* generate the tables if that has not happened yet */
    if(!tab_init)
        gen_tabs();
#endif

    /* the accepted lengths are exactly the multiples of 8 in [16, 32] */
    switch(blen)
    {
    case 16:
    case 24:
    case 32:
        cx->n_blk = blen;
        return aes_good;
    default:
        cx->n_blk = 0;
        return aes_bad;
    }
}

#endif
62 
/* Initialise the key schedule from the user supplied key. The key
   length is now specified in bytes - 16, 24 or 32 as appropriate.
   This corresponds to bit lengths of 128, 192 and 256 bits, and
   to Nk values of 4, 6 and 8 respectively.

   The following macros implement a single cycle in the key
   schedule generation process; each cycle produces nk new words.
   With nc the block length and nk the key length (both in 32-bit
   words) and Nr = max(nc, nk) + 6 rounds, the schedule needs
   nc * (Nr + 1) words, so the number of cycles needed is
   ceil((nc * (Nr + 1) - nk) / nk), which equals the loop count
   (nc * Nr + nc - 1) / nk used below:

    nk =      4  5  6  7  8
    -----------------------
    nc = 4   10  9  8  7  7
    nc = 5   14 11 10  9  9
    nc = 6   19 15 12 11 11
    nc = 7   24 19 16 13 13
    nc = 8   29 23 19 17 14

   Only nc and nk values of 4, 6 and 8 are used in this file.
*/
80 
81 #if defined(ENCRYPTION_KEY_SCHEDULE)
82 
/* One cycle of the encryption key expansion per macro invocation.

   ke4, ke6 and ke8 handle 16, 24 and 32 byte keys (nk = 4, 6 and 8 words).
   Cycle i starts from the previous nk schedule words held in ss[0..nk-1],
   overwrites them with the next nk words, and stores those words at
   k[nk*(i)+nk] .. k[nk*(i)+2*nk-1]. Each cycle mixes in rcon_tab[i].

   The kel* ("last") variants are used for the final cycle of the unrolled
   16 byte block code and store only the words that schedule still needs:
   kel4 stores all four words, kel6 and kel8 only the first four.

   All of these macros expect a uint32_t array named ss with at least nk
   entries in the calling scope; ls_box() and rcon_tab[] come from aesopt.h.
   NOTE(review): ls_box(x,3) is presumably SubWord(RotWord(x)) and
   ls_box(x,0) plain SubWord(x), as FIPS-197 key expansion requires;
   confirm against aesopt.h.
*/

/* 16 byte key: one cycle of four words */
#define ke4(k,i) \
{   k[4*(i)+4] = ss[0] ^= ls_box(ss[3],3) ^ rcon_tab[i]; k[4*(i)+5] = ss[1] ^= ss[0]; \
    k[4*(i)+6] = ss[2] ^= ss[1]; k[4*(i)+7] = ss[3] ^= ss[2]; \
}
/* 16 byte key, final cycle: all four words are still needed, so same as ke4 */
#define kel4(k,i) \
{   k[4*(i)+4] = ss[0] ^= ls_box(ss[3],3) ^ rcon_tab[i]; k[4*(i)+5] = ss[1] ^= ss[0]; \
    k[4*(i)+6] = ss[2] ^= ss[1]; k[4*(i)+7] = ss[3] ^= ss[2]; \
}

/* 24 byte key: one cycle of six words */
#define ke6(k,i) \
{   k[6*(i)+ 6] = ss[0] ^= ls_box(ss[5],3) ^ rcon_tab[i]; k[6*(i)+ 7] = ss[1] ^= ss[0]; \
    k[6*(i)+ 8] = ss[2] ^= ss[1]; k[6*(i)+ 9] = ss[3] ^= ss[2]; \
    k[6*(i)+10] = ss[4] ^= ss[3]; k[6*(i)+11] = ss[5] ^= ss[4]; \
}
/* 24 byte key, final cycle: only the first four of the six words are produced */
#define kel6(k,i) \
{   k[6*(i)+ 6] = ss[0] ^= ls_box(ss[5],3) ^ rcon_tab[i]; k[6*(i)+ 7] = ss[1] ^= ss[0]; \
    k[6*(i)+ 8] = ss[2] ^= ss[1]; k[6*(i)+ 9] = ss[3] ^= ss[2]; \
}

/* 32 byte key: one cycle of eight words; the fifth word of each cycle also
   passes through ls_box(...,0) */
#define ke8(k,i) \
{   k[8*(i)+ 8] = ss[0] ^= ls_box(ss[7],3) ^ rcon_tab[i]; k[8*(i)+ 9] = ss[1] ^= ss[0]; \
    k[8*(i)+10] = ss[2] ^= ss[1]; k[8*(i)+11] = ss[3] ^= ss[2]; \
    k[8*(i)+12] = ss[4] ^= ls_box(ss[3],0); k[8*(i)+13] = ss[5] ^= ss[4]; \
    k[8*(i)+14] = ss[6] ^= ss[5]; k[8*(i)+15] = ss[7] ^= ss[6]; \
}
/* 32 byte key, final cycle: only the first four of the eight words are produced */
#define kel8(k,i) \
{   k[8*(i)+ 8] = ss[0] ^= ls_box(ss[7],3) ^ rcon_tab[i]; k[8*(i)+ 9] = ss[1] ^= ss[0]; \
    k[8*(i)+10] = ss[2] ^= ss[1]; k[8*(i)+11] = ss[3] ^= ss[2]; \
}
112 
aes_enc_key(const unsigned char in_key[],unsigned int klen,aes_ctx cx[1])113 aes_rval aes_enc_key(const unsigned char in_key[], unsigned int klen, aes_ctx cx[1])
114 {   uint32_t    ss[8];
115 
116 #if !defined(FIXED_TABLES)
117     if(!tab_init) gen_tabs();
118 #endif
119 
120 #if !defined(BLOCK_SIZE)
121     if(!cx->n_blk) cx->n_blk = 16;
122 #else
123     cx->n_blk = BLOCK_SIZE;
124 #endif
125 
126     cx->n_blk = (cx->n_blk & ~3U) | 1;
127 
128     cx->k_sch[0] = ss[0] = word_in(in_key     );
129     cx->k_sch[1] = ss[1] = word_in(in_key +  4);
130     cx->k_sch[2] = ss[2] = word_in(in_key +  8);
131     cx->k_sch[3] = ss[3] = word_in(in_key + 12);
132 
133 #if (BLOCK_SIZE == 16) && (ENC_UNROLL != NONE)
134 
135     switch(klen)
136     {
137     case 16:    ke4(cx->k_sch, 0); ke4(cx->k_sch, 1);
138                 ke4(cx->k_sch, 2); ke4(cx->k_sch, 3);
139                 ke4(cx->k_sch, 4); ke4(cx->k_sch, 5);
140                 ke4(cx->k_sch, 6); ke4(cx->k_sch, 7);
141                 ke4(cx->k_sch, 8); kel4(cx->k_sch, 9);
142                 cx->n_rnd = 10; break;
143     case 24:    cx->k_sch[4] = ss[4] = word_in(in_key + 16);
144                 cx->k_sch[5] = ss[5] = word_in(in_key + 20);
145                 ke6(cx->k_sch, 0); ke6(cx->k_sch, 1);
146                 ke6(cx->k_sch, 2); ke6(cx->k_sch, 3);
147                 ke6(cx->k_sch, 4); ke6(cx->k_sch, 5);
148                 ke6(cx->k_sch, 6); kel6(cx->k_sch, 7);
149                 cx->n_rnd = 12; break;
150     case 32:    cx->k_sch[4] = ss[4] = word_in(in_key + 16);
151                 cx->k_sch[5] = ss[5] = word_in(in_key + 20);
152                 cx->k_sch[6] = ss[6] = word_in(in_key + 24);
153                 cx->k_sch[7] = ss[7] = word_in(in_key + 28);
154                 ke8(cx->k_sch, 0); ke8(cx->k_sch, 1);
155                 ke8(cx->k_sch, 2); ke8(cx->k_sch, 3);
156                 ke8(cx->k_sch, 4); ke8(cx->k_sch, 5);
157                 kel8(cx->k_sch, 6);
158                 cx->n_rnd = 14; break;
159     default:    cx->n_rnd = 0; return aes_bad;
160     }
161 #else
162     {   uint32_t i, l;
163         cx->n_rnd = ((klen >> 2) > nc ? (klen >> 2) : nc) + 6;
164         l = (nc * cx->n_rnd + nc - 1) / (klen >> 2);
165 
166         switch(klen)
167         {
168         case 16:    for(i = 0; i < l; ++i)
169                         ke4(cx->k_sch, i);
170                     break;
171         case 24:    cx->k_sch[4] = ss[4] = word_in(in_key + 16);
172                     cx->k_sch[5] = ss[5] = word_in(in_key + 20);
173                     for(i = 0; i < l; ++i)
174                         ke6(cx->k_sch, i);
175                     break;
176         case 32:    cx->k_sch[4] = ss[4] = word_in(in_key + 16);
177                     cx->k_sch[5] = ss[5] = word_in(in_key + 20);
178                     cx->k_sch[6] = ss[6] = word_in(in_key + 24);
179                     cx->k_sch[7] = ss[7] = word_in(in_key + 28);
180                     for(i = 0; i < l; ++i)
181                         ke8(cx->k_sch,  i);
182                     break;
183         default:    cx->n_rnd = 0; return aes_bad;
184         }
185     }
186 #endif
187 
188     return aes_good;
189 }
190 
191 #endif
192 
193 #if defined(DECRYPTION_KEY_SCHEDULE)
194 
/* Decryption key expansion helpers.

   With table driven decryption (DEC_ROUND != NO_TABLES), every round key
   except the first and the last is stored after inv_mcol(), as the
   equivalent inverse cipher of FIPS-197 section 5.3.5 requires. ff()
   applies that transform. d_vars expands to dec_imvars; NOTE(review): these
   are presumably declarations of temporaries that inv_mcol() needs, so
   confirm in aesopt.h. Without tables, ff() is the identity and d_vars is
   empty.
*/
#if (DEC_ROUND != NO_TABLES)
#define d_vars  dec_imvars
#define ff(x)   inv_mcol(x)
#else
#define ff(x)   (x)
#define d_vars
#endif

/* 16 byte key cycles.

   Roles:
   - kdf4 is the first cycle; it reads the untransformed key words k[0..3].
   - kd4 is used for each middle cycle.
   - kdl4 is the last cycle; it stores its words without ff() because they
     form the last round key.

   kd4 relies on ff() distributing over XOR (InvMixColumns is linear over
   GF(2)). Rather than transforming each new word, it XORs ff(round term)
   into the already transformed words of the previous cycle,
   k[4*(i)] .. k[4*(i)+3].

   In the enabled (#if 1) variant, ss[0..3] are kept in a folded form:
   - kdf4 first rewrites the raw key words.
   - After that, cycle i only XORs its round term into ss[i % 4], and
     ss[(i+3) % 4] is the input to ls_box().
   - kdl4 unfolds ss[] to recover the final four words.
   ss[4] is scratch.

   The #else variant is the straightforward form and is currently disabled.
*/
#if 1
#define kdf4(k,i) \
{   ss[0] = ss[0] ^ ss[2] ^ ss[1] ^ ss[3]; ss[1] = ss[1] ^ ss[3]; ss[2] = ss[2] ^ ss[3]; ss[3] = ss[3]; \
    ss[4] = ls_box(ss[(i+3) % 4], 3) ^ rcon_tab[i]; ss[i % 4] ^= ss[4]; \
    ss[4] ^= k[4*(i)];   k[4*(i)+4] = ff(ss[4]); ss[4] ^= k[4*(i)+1]; k[4*(i)+5] = ff(ss[4]); \
    ss[4] ^= k[4*(i)+2]; k[4*(i)+6] = ff(ss[4]); ss[4] ^= k[4*(i)+3]; k[4*(i)+7] = ff(ss[4]); \
}
#define kd4(k,i) \
{   ss[4] = ls_box(ss[(i+3) % 4], 3) ^ rcon_tab[i]; ss[i % 4] ^= ss[4]; ss[4] = ff(ss[4]); \
    k[4*(i)+4] = ss[4] ^= k[4*(i)]; k[4*(i)+5] = ss[4] ^= k[4*(i)+1]; \
    k[4*(i)+6] = ss[4] ^= k[4*(i)+2]; k[4*(i)+7] = ss[4] ^= k[4*(i)+3]; \
}
#define kdl4(k,i) \
{   ss[4] = ls_box(ss[(i+3) % 4], 3) ^ rcon_tab[i]; ss[i % 4] ^= ss[4]; \
    k[4*(i)+4] = (ss[0] ^= ss[1]) ^ ss[2] ^ ss[3]; k[4*(i)+5] = ss[1] ^ ss[3]; \
    k[4*(i)+6] = ss[0]; k[4*(i)+7] = ss[1]; \
}
#else
#define kdf4(k,i) \
{   ss[0] ^= ls_box(ss[3],3) ^ rcon_tab[i]; k[4*(i)+ 4] = ff(ss[0]); ss[1] ^= ss[0]; k[4*(i)+ 5] = ff(ss[1]); \
    ss[2] ^= ss[1]; k[4*(i)+ 6] = ff(ss[2]); ss[3] ^= ss[2]; k[4*(i)+ 7] = ff(ss[3]); \
}
#define kd4(k,i) \
{   ss[4] = ls_box(ss[3],3) ^ rcon_tab[i]; \
    ss[0] ^= ss[4]; ss[4] = ff(ss[4]); k[4*(i)+ 4] = ss[4] ^= k[4*(i)]; \
    ss[1] ^= ss[0]; k[4*(i)+ 5] = ss[4] ^= k[4*(i)+ 1]; \
    ss[2] ^= ss[1]; k[4*(i)+ 6] = ss[4] ^= k[4*(i)+ 2]; \
    ss[3] ^= ss[2]; k[4*(i)+ 7] = ss[4] ^= k[4*(i)+ 3]; \
}
#define kdl4(k,i) \
{   ss[0] ^= ls_box(ss[3],3) ^ rcon_tab[i]; k[4*(i)+ 4] = ss[0]; ss[1] ^= ss[0]; k[4*(i)+ 5] = ss[1]; \
    ss[2] ^= ss[1]; k[4*(i)+ 6] = ss[2]; ss[3] ^= ss[2]; k[4*(i)+ 7] = ss[3]; \
}
#endif

/* 24 byte key cycles.
   - kdf6 (first cycle) stores ff() of each new word.
   - kd6 (middle cycles) uses the same XOR-into-previous-transformed-words
     trick as kd4, with ss[6] as scratch.
   - kdl6 (last cycle) stores only the first four words, untransformed.
*/
#define kdf6(k,i) \
{   ss[0] ^= ls_box(ss[5],3) ^ rcon_tab[i]; k[6*(i)+ 6] = ff(ss[0]); ss[1] ^= ss[0]; k[6*(i)+ 7] = ff(ss[1]); \
    ss[2] ^= ss[1]; k[6*(i)+ 8] = ff(ss[2]); ss[3] ^= ss[2]; k[6*(i)+ 9] = ff(ss[3]); \
    ss[4] ^= ss[3]; k[6*(i)+10] = ff(ss[4]); ss[5] ^= ss[4]; k[6*(i)+11] = ff(ss[5]); \
}
#define kd6(k,i) \
{   ss[6] = ls_box(ss[5],3) ^ rcon_tab[i]; \
    ss[0] ^= ss[6]; ss[6] = ff(ss[6]); k[6*(i)+ 6] = ss[6] ^= k[6*(i)]; \
    ss[1] ^= ss[0]; k[6*(i)+ 7] = ss[6] ^= k[6*(i)+ 1]; \
    ss[2] ^= ss[1]; k[6*(i)+ 8] = ss[6] ^= k[6*(i)+ 2]; \
    ss[3] ^= ss[2]; k[6*(i)+ 9] = ss[6] ^= k[6*(i)+ 3]; \
    ss[4] ^= ss[3]; k[6*(i)+10] = ss[6] ^= k[6*(i)+ 4]; \
    ss[5] ^= ss[4]; k[6*(i)+11] = ss[6] ^= k[6*(i)+ 5]; \
}
#define kdl6(k,i) \
{   ss[0] ^= ls_box(ss[5],3) ^ rcon_tab[i]; k[6*(i)+ 6] = ss[0]; ss[1] ^= ss[0]; k[6*(i)+ 7] = ss[1]; \
    ss[2] ^= ss[1]; k[6*(i)+ 8] = ss[2]; ss[3] ^= ss[2]; k[6*(i)+ 9] = ss[3]; \
}

/* 32 byte key cycles, same pattern as above. The fifth word of each cycle
   also passes through ls_box(...,0). kd8 uses a block-local g as scratch
   and restarts the XOR chain at word 4 because the new ls_box term breaks
   it. kdl8 stores only the first four words, untransformed.
*/
#define kdf8(k,i) \
{   ss[0] ^= ls_box(ss[7],3) ^ rcon_tab[i]; k[8*(i)+ 8] = ff(ss[0]); ss[1] ^= ss[0]; k[8*(i)+ 9] = ff(ss[1]); \
    ss[2] ^= ss[1]; k[8*(i)+10] = ff(ss[2]); ss[3] ^= ss[2]; k[8*(i)+11] = ff(ss[3]); \
    ss[4] ^= ls_box(ss[3],0); k[8*(i)+12] = ff(ss[4]); ss[5] ^= ss[4]; k[8*(i)+13] = ff(ss[5]); \
    ss[6] ^= ss[5]; k[8*(i)+14] = ff(ss[6]); ss[7] ^= ss[6]; k[8*(i)+15] = ff(ss[7]); \
}
#define kd8(k,i) \
{   uint32_t g = ls_box(ss[7],3) ^ rcon_tab[i]; \
    ss[0] ^= g; g = ff(g); k[8*(i)+ 8] = g ^= k[8*(i)]; \
    ss[1] ^= ss[0]; k[8*(i)+ 9] = g ^= k[8*(i)+ 1]; \
    ss[2] ^= ss[1]; k[8*(i)+10] = g ^= k[8*(i)+ 2]; \
    ss[3] ^= ss[2]; k[8*(i)+11] = g ^= k[8*(i)+ 3]; \
    g = ls_box(ss[3],0); \
    ss[4] ^= g; g = ff(g); k[8*(i)+12] = g ^= k[8*(i)+ 4]; \
    ss[5] ^= ss[4]; k[8*(i)+13] = g ^= k[8*(i)+ 5]; \
    ss[6] ^= ss[5]; k[8*(i)+14] = g ^= k[8*(i)+ 6]; \
    ss[7] ^= ss[6]; k[8*(i)+15] = g ^= k[8*(i)+ 7]; \
}
#define kdl8(k,i) \
{   ss[0] ^= ls_box(ss[7],3) ^ rcon_tab[i]; k[8*(i)+ 8] = ss[0]; ss[1] ^= ss[0]; k[8*(i)+ 9] = ss[1]; \
    ss[2] ^= ss[1]; k[8*(i)+10] = ss[2]; ss[3] ^= ss[2]; k[8*(i)+11] = ss[3]; \
}
279 
aes_dec_key(const unsigned char in_key[],unsigned int klen,aes_ctx cx[1])280 aes_rval aes_dec_key(const unsigned char in_key[], unsigned int klen, aes_ctx cx[1])
281 {   uint32_t    ss[8];
282     d_vars
283 
284 #if !defined(FIXED_TABLES)
285     if(!tab_init) gen_tabs();
286 #endif
287 
288 #if !defined(BLOCK_SIZE)
289     if(!cx->n_blk) cx->n_blk = 16;
290 #else
291     cx->n_blk = BLOCK_SIZE;
292 #endif
293 
294     cx->n_blk = (cx->n_blk & ~3U) | 2;
295 
296     cx->k_sch[0] = ss[0] = word_in(in_key     );
297     cx->k_sch[1] = ss[1] = word_in(in_key +  4);
298     cx->k_sch[2] = ss[2] = word_in(in_key +  8);
299     cx->k_sch[3] = ss[3] = word_in(in_key + 12);
300 
301 #if (BLOCK_SIZE == 16) && (DEC_UNROLL != NONE)
302 
303     switch(klen)
304     {
305     case 16:    kdf4(cx->k_sch, 0); kd4(cx->k_sch, 1);
306                 kd4(cx->k_sch, 2); kd4(cx->k_sch, 3);
307                 kd4(cx->k_sch, 4); kd4(cx->k_sch, 5);
308                 kd4(cx->k_sch, 6); kd4(cx->k_sch, 7);
309                 kd4(cx->k_sch, 8); kdl4(cx->k_sch, 9);
310                 cx->n_rnd = 10; break;
311     case 24:    ss[4] = word_in(in_key + 16);
312 		cx->k_sch[4] = ff(ss[4]);
313 		ss[5] = word_in(in_key + 20);
314                 cx->k_sch[5] = ff(ss[5]);
315                 kdf6(cx->k_sch, 0); kd6(cx->k_sch, 1);
316                 kd6(cx->k_sch, 2); kd6(cx->k_sch, 3);
317                 kd6(cx->k_sch, 4); kd6(cx->k_sch, 5);
318                 kd6(cx->k_sch, 6); kdl6(cx->k_sch, 7);
319                 cx->n_rnd = 12; break;
320     case 32:    ss[4] = word_in(in_key + 16);
321 		cx->k_sch[4] = ff(ss[4]);
322 		ss[5] = word_in(in_key + 20);
323                 cx->k_sch[5] = ff(ss[5]);
324 		ss[6] = word_in(in_key + 24);
325                 cx->k_sch[6] = ff(ss[6]);
326 		ss[7] = word_in(in_key + 28);
327                 cx->k_sch[7] = ff(ss[7]);
328                 kdf8(cx->k_sch, 0); kd8(cx->k_sch, 1);
329                 kd8(cx->k_sch, 2); kd8(cx->k_sch, 3);
330                 kd8(cx->k_sch, 4); kd8(cx->k_sch, 5);
331                 kdl8(cx->k_sch, 6);
332                 cx->n_rnd = 14; break;
333     default:    cx->n_rnd = 0; return aes_bad;
334     }
335 #else
336     {   uint32_t i, l;
337         cx->n_rnd = ((klen >> 2) > nc ? (klen >> 2) : nc) + 6;
338         l = (nc * cx->n_rnd + nc - 1) / (klen >> 2);
339 
340         switch(klen)
341         {
342         case 16:
343                     for(i = 0; i < l; ++i)
344                         ke4(cx->k_sch, i);
345                     break;
346         case 24:    cx->k_sch[4] = ss[4] = word_in(in_key + 16);
347                     cx->k_sch[5] = ss[5] = word_in(in_key + 20);
348                     for(i = 0; i < l; ++i)
349                         ke6(cx->k_sch, i);
350                     break;
351         case 32:    cx->k_sch[4] = ss[4] = word_in(in_key + 16);
352                     cx->k_sch[5] = ss[5] = word_in(in_key + 20);
353                     cx->k_sch[6] = ss[6] = word_in(in_key + 24);
354                     cx->k_sch[7] = ss[7] = word_in(in_key + 28);
355                     for(i = 0; i < l; ++i)
356                         ke8(cx->k_sch,  i);
357                     break;
358         default:    cx->n_rnd = 0; return aes_bad;
359         }
360 #if (DEC_ROUND != NO_TABLES)
361         for(i = nc; i < nc * cx->n_rnd; ++i)
362             cx->k_sch[i] = inv_mcol(cx->k_sch[i]);
363 #endif
364     }
365 #endif
366 
367     return aes_good;
368 }
369 
370 #endif
371