1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
5 #ifdef FREEBL_NO_DEPEND
6 #include "stubs.h"
7 #endif
8 
9 #include "prinit.h"
10 #include "prenv.h"
11 #include "prerr.h"
12 #include "secerr.h"
13 
14 #include "prtypes.h"
15 #include "blapi.h"
16 #include "rijndael.h"
17 
18 #include "cts.h"
19 #include "ctr.h"
20 #include "gcm.h"
21 #include "mpi.h"
22 
23 #if !defined(IS_LITTLE_ENDIAN) && !defined(NSS_X86_OR_X64)
// hardware AES is not tested yet on big-endian (e.g. ARM) platforms
25 #undef USE_HW_AES
26 #endif
27 
28 #ifdef USE_HW_AES
29 #ifdef NSS_X86_OR_X64
30 #include "intel-aes.h"
31 #else
32 #include "aes-armv8.h"
33 #endif
34 #endif /* USE_HW_AES */
35 #ifdef INTEL_GCM
36 #include "intel-gcm.h"
37 #endif /* INTEL_GCM */
38 
/* Forward declarations for the hardware-accelerated ("native") AES helpers.
 * The ECB/CBC code in this file only calls them when aesni_support() is
 * true.  On NSS_X86_OR_X64 builds the real implementations are expected to
 * come from another translation unit (not visible here — presumably the
 * AES-NI code); on all other builds this file defines failing stubs. */
/* Key expansion for the native code path (Nk = key length in 32-bit words). */
void rijndael_native_key_expansion(AESContext *cx, const unsigned char *key,
                                   unsigned int Nk);
/* Encrypt a single AES block from input into output. */
void rijndael_native_encryptBlock(AESContext *cx,
                                  unsigned char *output,
                                  const unsigned char *input);
/* Decrypt a single AES block from input into output. */
void rijndael_native_decryptBlock(AESContext *cx,
                                  unsigned char *output,
                                  const unsigned char *input);
/* Presumably out = a ^ b over one AES block (used exactly like the portable
 * xorBlock in the CBC code, including with out == a) — TODO confirm. */
void native_xorBlock(unsigned char *out,
                     const unsigned char *a,
                     const unsigned char *b);
51 
52 /* Stub definitions for the above rijndael_native_* functions, which
53  * shouldn't be used unless NSS_X86_OR_X64 is defined */
54 #ifndef NSS_X86_OR_X64
55 void
rijndael_native_key_expansion(AESContext * cx,const unsigned char * key,unsigned int Nk)56 rijndael_native_key_expansion(AESContext *cx, const unsigned char *key,
57                               unsigned int Nk)
58 {
59     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
60     PORT_Assert(0);
61 }
62 
63 void
rijndael_native_encryptBlock(AESContext * cx,unsigned char * output,const unsigned char * input)64 rijndael_native_encryptBlock(AESContext *cx,
65                              unsigned char *output,
66                              const unsigned char *input)
67 {
68     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
69     PORT_Assert(0);
70 }
71 
72 void
rijndael_native_decryptBlock(AESContext * cx,unsigned char * output,const unsigned char * input)73 rijndael_native_decryptBlock(AESContext *cx,
74                              unsigned char *output,
75                              const unsigned char *input)
76 {
77     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
78     PORT_Assert(0);
79 }
80 
81 void
native_xorBlock(unsigned char * out,const unsigned char * a,const unsigned char * b)82 native_xorBlock(unsigned char *out, const unsigned char *a,
83                 const unsigned char *b)
84 {
85     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
86     PORT_Assert(0);
87 }
88 #endif /* NSS_X86_OR_X64 */
89 
90 /*
91  * There are currently three ways to build this code, varying in performance
92  * and code size.
93  *
94  * RIJNDAEL_INCLUDE_TABLES         Include all tables from rijndael32.tab
95  * RIJNDAEL_GENERATE_VALUES        Do not store tables, generate the table
96  *                                 values "on-the-fly", using gfm
97  * RIJNDAEL_GENERATE_VALUES_MACRO  Same as above, but use macros
98  *
99  * The default is RIJNDAEL_INCLUDE_TABLES.
100  */
101 
102 /*
 * When building RIJNDAEL_INCLUDE_TABLES, includes S**-1, Rcon, T[0..3],
 *                                                 T**-1[0..3], IMXC[0..3]
105  * When building anything else, includes S, S**-1, Rcon
106  */
107 #include "rijndael32.tab"
108 
109 #if defined(RIJNDAEL_INCLUDE_TABLES)
/*
 * RIJNDAEL_INCLUDE_TABLES
 *
 * Every accessor below is a direct lookup into a table precomputed in
 * rijndael32.tab: T0..T3 (forward round), TInv0..TInv3 (inverse round) and
 * IMXC0..IMXC3 (InvMixColumn, used on the decryption key schedule).
 */
#define T0(i) _T0[i]
#define T1(i) _T1[i]
#define T2(i) _T2[i]
#define T3(i) _T3[i]
#define TInv0(i) _TInv0[i]
#define TInv1(i) _TInv1[i]
#define TInv2(i) _TInv2[i]
#define TInv3(i) _TInv3[i]
#define IMXC0(b) _IMXC0[b]
#define IMXC1(b) _IMXC1[b]
#define IMXC2(b) _IMXC2[b]
#define IMXC3(b) _IMXC3[b]
/* The S-box can be recovered from the T-tables: assuming the stored tables
 * use the same layout as the G_T* generators in this file, the least
 * significant byte of T3[b] (little-endian) or T1[b] (big-endian) is
 * 01 * S[b], so truncating to PRUint8 yields S[b]. */
#ifdef IS_LITTLE_ENDIAN
#define SBOX(b) ((PRUint8)_T3[b])
#else
#define SBOX(b) ((PRUint8)_T1[b])
#endif
#define SINV(b) (_SInv[b])
132 
133 #else /* not RIJNDAEL_INCLUDE_TABLES */
134 
/*
 * Code for generating T-table values.
 */

/* WORD4(b0, b1, b2, b3) packs four byte values into a PRUint32 whose bytes
 * lie in memory in the order b0, b1, b2, b3 on either endianness: b0 goes to
 * the least significant byte on little-endian hosts and to the most
 * significant byte on big-endian hosts.  The arguments are cast without
 * extra parentheses, so pass simple or fully parenthesized expressions. */
#ifdef IS_LITTLE_ENDIAN
#define WORD4(b0, b1, b2, b3) \
    ((((PRUint32)b3) << 24) | \
     (((PRUint32)b2) << 16) | \
     (((PRUint32)b1) << 8) |  \
     ((PRUint32)b0))
#else
#define WORD4(b0, b1, b2, b3) \
    ((((PRUint32)b0) << 24) | \
     (((PRUint32)b1) << 16) | \
     (((PRUint32)b2) << 8) |  \
     ((PRUint32)b3))
#endif

/*
 * Define the S and S**-1 tables (both have been stored)
 */
#define SBOX(b) (_S[b])
#define SINV(b) (_SInv[b])
158 
/*
 * The function xtime, used for Galois field multiplication: multiplies a by
 * x (i.e. by {02}) in GF(2**8), reducing by the AES field polynomial
 * x**8 + x**4 + x**3 + x + 1 (0x1b is its low byte).
 *
 * The result is NOT masked to 8 bits (it may be up to 0x1ff); callers that
 * need a byte mask it (GFM02) or truncate it by assigning to a PRUint8.
 * The argument is parenthesized so compound expressions such as
 * XTIME(x ^ y) expand correctly, but it is still evaluated more than once,
 * so it must be free of side effects.
 */
#define XTIME(a) \
    (((a)&0x80) ? (((a) << 1) ^ 0x1b) : ((a) << 1))
164 
165 /* Choose GFM method (macros or function) */
166 #if defined(RIJNDAEL_GENERATE_VALUES_MACRO)
167 
/*
 * Galois field GF(2**8) multipliers, in macro form
 *
 * Each multiplier is built from XTIME doublings (02, 04, 08) XORed
 * together.  The macros nest and repeat their argument, so expansions grow
 * quickly and the argument must be free of side effects.
 */
#define GFM01(a) \
    (a) /* a * 01 = a, the identity */
#define GFM02(a) \
    (XTIME(a) & 0xff) /* a * 02 = xtime(a) */
#define GFM04(a) \
    (GFM02(GFM02(a))) /* a * 04 = xtime**2(a) */
#define GFM08(a) \
    (GFM02(GFM04(a))) /* a * 08 = xtime**3(a) */
#define GFM03(a) \
    (GFM01(a) ^ GFM02(a)) /* a * 03 = a * (01 + 02) */
#define GFM09(a) \
    (GFM01(a) ^ GFM08(a)) /* a * 09 = a * (01 + 08) */
#define GFM0B(a) \
    (GFM01(a) ^ GFM02(a) ^ GFM08(a)) /* a * 0B = a * (01 + 02 + 08) */
#define GFM0D(a) \
    (GFM01(a) ^ GFM04(a) ^ GFM08(a)) /* a * 0D = a * (01 + 04 + 08) */
#define GFM0E(a) \
    (GFM02(a) ^ GFM04(a) ^ GFM08(a)) /* a * 0E = a * (02 + 04 + 08) */
189 
190 #else /* RIJNDAEL_GENERATE_VALUES */
191 
192 /* GF_MULTIPLY
193  *
194  * multiply two bytes represented in GF(2**8), mod (x**4 + 1)
195  */
196 PRUint8
gfm(PRUint8 a,PRUint8 b)197 gfm(PRUint8 a, PRUint8 b)
198 {
199     PRUint8 res = 0;
200     while (b > 0) {
201         res = (b & 0x01) ? res ^ a : res;
202         a = XTIME(a);
203         b >>= 1;
204     }
205     return res;
206 }
207 
/* GF(2**8) multipliers for RIJNDAEL_GENERATE_VALUES: the identity and the
 * doubling stay inline; every other constant multiplier calls gfm(). */
#define GFM01(a) \
    (a) /* a * 01 = a, the identity */
#define GFM02(a) \
    (XTIME(a) & 0xff) /* a * 02 = xtime(a) */
#define GFM03(a) \
    (gfm(a, 0x03)) /* a * 03 */
#define GFM09(a) \
    (gfm(a, 0x09)) /* a * 09 */
#define GFM0B(a) \
    (gfm(a, 0x0B)) /* a * 0B */
#define GFM0D(a) \
    (gfm(a, 0x0D)) /* a * 0D */
#define GFM0E(a) \
    (gfm(a, 0x0E)) /* a * 0E */
222 
223 #endif /* choosing GFM function */
224 
/*
 * The T-tables
 *
 * Each entry combines SubBytes with one column of MixColumns.  For
 * s = S[i], T0 holds the bytes [02*s, 01*s, 01*s, 03*s] in WORD4 memory
 * order, and T1..T3 are successive one-byte rotations of that column.
 */
#define G_T0(i) \
    (WORD4(GFM02(SBOX(i)), GFM01(SBOX(i)), GFM01(SBOX(i)), GFM03(SBOX(i))))
#define G_T1(i) \
    (WORD4(GFM03(SBOX(i)), GFM02(SBOX(i)), GFM01(SBOX(i)), GFM01(SBOX(i))))
#define G_T2(i) \
    (WORD4(GFM01(SBOX(i)), GFM03(SBOX(i)), GFM02(SBOX(i)), GFM01(SBOX(i))))
#define G_T3(i) \
    (WORD4(GFM01(SBOX(i)), GFM01(SBOX(i)), GFM03(SBOX(i)), GFM02(SBOX(i))))

/*
 * The inverse T-tables
 *
 * Same construction with the inverse S-box and the InvMixColumns
 * coefficients: TInv0 holds [0E*t, 09*t, 0D*t, 0B*t] for t = SInv[i],
 * and TInv1..TInv3 are its rotations.
 */
#define G_TInv0(i) \
    (WORD4(GFM0E(SINV(i)), GFM09(SINV(i)), GFM0D(SINV(i)), GFM0B(SINV(i))))
#define G_TInv1(i) \
    (WORD4(GFM0B(SINV(i)), GFM0E(SINV(i)), GFM09(SINV(i)), GFM0D(SINV(i))))
#define G_TInv2(i) \
    (WORD4(GFM0D(SINV(i)), GFM0B(SINV(i)), GFM0E(SINV(i)), GFM09(SINV(i))))
#define G_TInv3(i) \
    (WORD4(GFM09(SINV(i)), GFM0D(SINV(i)), GFM0B(SINV(i)), GFM0E(SINV(i))))

/*
 * The inverse mix column tables
 *
 * Like the inverse T-tables but applied to the raw byte i (no S-box);
 * used to run InvMixColumn over round-key words.
 */
#define G_IMXC0(i) \
    (WORD4(GFM0E(i), GFM09(i), GFM0D(i), GFM0B(i)))
#define G_IMXC1(i) \
    (WORD4(GFM0B(i), GFM0E(i), GFM09(i), GFM0D(i)))
#define G_IMXC2(i) \
    (WORD4(GFM0D(i), GFM0B(i), GFM0E(i), GFM09(i)))
#define G_IMXC3(i) \
    (WORD4(GFM09(i), GFM0D(i), GFM0B(i), GFM0E(i)))
260 
261 /* Now choose the T-table indexing method */
262 #if defined(RIJNDAEL_GENERATE_VALUES)
263 /* generate values for the tables with a function*/
264 static PRUint32
gen_TInvXi(PRUint8 tx,PRUint8 i)265 gen_TInvXi(PRUint8 tx, PRUint8 i)
266 {
267     PRUint8 si01, si02, si03, si04, si08, si09, si0B, si0D, si0E;
268     si01 = SINV(i);
269     si02 = XTIME(si01);
270     si04 = XTIME(si02);
271     si08 = XTIME(si04);
272     si03 = si02 ^ si01;
273     si09 = si08 ^ si01;
274     si0B = si08 ^ si03;
275     si0D = si09 ^ si04;
276     si0E = si08 ^ si04 ^ si02;
277     switch (tx) {
278         case 0:
279             return WORD4(si0E, si09, si0D, si0B);
280         case 1:
281             return WORD4(si0B, si0E, si09, si0D);
282         case 2:
283             return WORD4(si0D, si0B, si0E, si09);
284         case 3:
285             return WORD4(si09, si0D, si0B, si0E);
286     }
287     return -1;
288 }
/* RIJNDAEL_GENERATE_VALUES accessors: the forward and IMXC tables expand to
 * the G_* generator macros, while the inverse tables call gen_TInvXi, which
 * derives all needed multiples of SInv[i] once per lookup. */
#define T0(i) G_T0(i)
#define T1(i) G_T1(i)
#define T2(i) G_T2(i)
#define T3(i) G_T3(i)
#define TInv0(i) gen_TInvXi(0, i)
#define TInv1(i) gen_TInvXi(1, i)
#define TInv2(i) gen_TInvXi(2, i)
#define TInv3(i) gen_TInvXi(3, i)
#define IMXC0(b) G_IMXC0(b)
#define IMXC1(b) G_IMXC1(b)
#define IMXC2(b) G_IMXC2(b)
#define IMXC3(b) G_IMXC3(b)
301 #else /* RIJNDAEL_GENERATE_VALUES_MACRO */
/* generate values for the tables with macros: every accessor expands to the
 * corresponding G_* generator macro, so no table is stored and no helper
 * function is called */
#define T0(i) G_T0(i)
#define T1(i) G_T1(i)
#define T2(i) G_T2(i)
#define T3(i) G_T3(i)
#define TInv0(i) G_TInv0(i)
#define TInv1(i) G_TInv1(i)
#define TInv2(i) G_TInv2(i)
#define TInv3(i) G_TInv3(i)
#define IMXC0(b) G_IMXC0(b)
#define IMXC1(b) G_IMXC1(b)
#define IMXC2(b) G_IMXC2(b)
#define IMXC3(b) G_IMXC3(b)
315 #endif /* choose T-table indexing method */
316 
317 #endif /* not RIJNDAEL_INCLUDE_TABLES */
318 
319 /**************************************************************************
320  *
321  * Stuff related to the Rijndael key schedule
322  *
323  *************************************************************************/
324 
/* SUBBYTE(w): the key-schedule SubWord, applying the S-box to each of the
 * four bytes of the 32-bit word w.  w is evaluated four times and must be
 * free of side effects; it is parenthesized so compound expressions
 * expand correctly. */
#define SUBBYTE(w)                                  \
    ((((PRUint32)SBOX(((w) >> 24) & 0xff)) << 24) | \
     (((PRUint32)SBOX(((w) >> 16) & 0xff)) << 16) | \
     (((PRUint32)SBOX(((w) >> 8) & 0xff)) << 8) |   \
     (((PRUint32)SBOX((w)&0xff))))

/* ROTBYTE(b): the key-schedule RotWord on a 32-bit word.  In memory order
 * the bytes [a0, a1, a2, a3] become [a1, a2, a3, a0], hence the shift
 * direction depends on host endianness.  b is evaluated twice. */
#ifdef IS_LITTLE_ENDIAN
#define ROTBYTE(b) \
    (((b) >> 8) | ((b) << 24))
#else
#define ROTBYTE(b) \
    (((b) << 8) | ((b) >> 24))
#endif
338 
339 /* rijndael_key_expansion7
340  *
341  * Generate the expanded key from the key input by the user.
342  * XXX
343  * Nk == 7 (224 key bits) is a weird case.  Since Nk > 6, an added SubByte
344  * transformation is done periodically.  The period is every 4 bytes, and
345  * since 7%4 != 0 this happens at different times for each key word (unlike
346  * Nk == 8 where it happens twice in every key word, in the same positions).
347  * For now, I'm implementing this case "dumbly", w/o any unrolling.
348  */
349 static void
rijndael_key_expansion7(AESContext * cx,const unsigned char * key,unsigned int Nk)350 rijndael_key_expansion7(AESContext *cx, const unsigned char *key, unsigned int Nk)
351 {
352     unsigned int i;
353     PRUint32 *W;
354     PRUint32 *pW;
355     PRUint32 tmp;
356     W = cx->k.expandedKey;
357     /* 1.  the first Nk words contain the cipher key */
358     memcpy(W, key, Nk * 4);
359     i = Nk;
360     /* 2.  loop until full expanded key is obtained */
361     pW = W + i - 1;
362     for (; i < cx->Nb * (cx->Nr + 1); ++i) {
363         tmp = *pW++;
364         if (i % Nk == 0)
365             tmp = SUBBYTE(ROTBYTE(tmp)) ^ Rcon[i / Nk - 1];
366         else if (i % Nk == 4)
367             tmp = SUBBYTE(tmp);
368         *pW = W[i - Nk] ^ tmp;
369     }
370 }
371 
/* rijndael_key_expansion
 *
 * Generate the expanded key from the key input by the user.
 *
 * Fills cx->k.expandedKey with cx->Nb * (cx->Nr + 1) words.  Nk == 7 is
 * delegated to rijndael_key_expansion7; the remaining sizes (Nk == 4, 5, 6
 * and 8) use the partially unrolled loop below, which produces one group of
 * Nk words per iteration.
 */
static void
rijndael_key_expansion(AESContext *cx, const unsigned char *key, unsigned int Nk)
{
    unsigned int i;
    PRUint32 *W;
    PRUint32 *pW;
    PRUint32 tmp;
    unsigned int round_key_words = cx->Nb * (cx->Nr + 1);
    if (Nk == 7) {
        rijndael_key_expansion7(cx, key, Nk);
        return;
    }
    W = cx->k.expandedKey;
    /* The first Nk words contain the input cipher key */
    memcpy(W, key, Nk * 4);
    i = Nk;
    /* pW trails the word being generated: "tmp = *pW++" reads W[i - 1] and
     * leaves pW pointing at W[i], the word that is then written. */
    pW = W + i - 1;
    /* Loop over all sets of Nk words, except the last */
    while (i < round_key_words - Nk) {
        /* First word of each group (i % Nk == 0): RotWord, SubWord, Rcon */
        tmp = *pW++;
        tmp = SUBBYTE(ROTBYTE(tmp)) ^ Rcon[i / Nk - 1];
        *pW = W[i++ - Nk] ^ tmp;
        tmp = *pW++;
        *pW = W[i++ - Nk] ^ tmp;
        tmp = *pW++;
        *pW = W[i++ - Nk] ^ tmp;
        tmp = *pW++;
        *pW = W[i++ - Nk] ^ tmp;
        if (Nk == 4)
            continue;
        /* Remaining Nk - 4 words of the group.  The cases intentionally fall
         * through, so "case N" generates N - 4 more words.  (Nk == 7 never
         * reaches here; it returned early above.) */
        switch (Nk) {
            case 8:
                /* For Nk > 6 the word at i % Nk == 4 also gets SubWord */
                tmp = *pW++;
                tmp = SUBBYTE(tmp);
                *pW = W[i++ - Nk] ^ tmp;
            /* fallthrough */
            case 7:
                tmp = *pW++;
                *pW = W[i++ - Nk] ^ tmp;
            /* fallthrough */
            case 6:
                tmp = *pW++;
                *pW = W[i++ - Nk] ^ tmp;
            /* fallthrough */
            case 5:
                tmp = *pW++;
                *pW = W[i++ - Nk] ^ tmp;
        }
    }
    /* Generate the last word */
    /* (i is a multiple of Nk here, so this is a group-leading word.) */
    tmp = *pW++;
    tmp = SUBBYTE(ROTBYTE(tmp)) ^ Rcon[i / Nk - 1];
    *pW = W[i++ - Nk] ^ tmp;
    /* The schedule length Nb * (Nr + 1) need not be a multiple of Nk, so the
     * final group may be partial and is finished one word at a time.  Since
     * the above loop generated all but the last Nk key words, no further
     * RotWord/Rcon step is needed.
     */
    if (Nk < 8) {
        for (; i < round_key_words; ++i) {
            tmp = *pW++;
            *pW = W[i - Nk] ^ tmp;
        }
    } else {
        /* except in the case when Nk == 8.  Then one more SubByte may have
         * to be performed, at i % Nk == 4.
         */
        for (; i < round_key_words; ++i) {
            tmp = *pW++;
            if (i % Nk == 4)
                tmp = SUBBYTE(tmp);
            *pW = W[i - Nk] ^ tmp;
        }
    }
}
447 
448 /* rijndael_invkey_expansion
449  *
450  * Generate the expanded key for the inverse cipher from the key input by
451  * the user.
452  */
453 static void
rijndael_invkey_expansion(AESContext * cx,const unsigned char * key,unsigned int Nk)454 rijndael_invkey_expansion(AESContext *cx, const unsigned char *key, unsigned int Nk)
455 {
456     unsigned int r;
457     PRUint32 *roundkeyw;
458     PRUint8 *b;
459     int Nb = cx->Nb;
460     /* begins like usual key expansion ... */
461     rijndael_key_expansion(cx, key, Nk);
462     /* ... but has the additional step of InvMixColumn,
463      * excepting the first and last round keys.
464      */
465     roundkeyw = cx->k.expandedKey + cx->Nb;
466     for (r = 1; r < cx->Nr; ++r) {
467         /* each key word, roundkeyw, represents a column in the key
468          * matrix.  Each column is multiplied by the InvMixColumn matrix.
469          *   [ 0E 0B 0D 09 ]   [ b0 ]
470          *   [ 09 0E 0B 0D ] * [ b1 ]
471          *   [ 0D 09 0E 0B ]   [ b2 ]
472          *   [ 0B 0D 09 0E ]   [ b3 ]
473          */
474         b = (PRUint8 *)roundkeyw;
475         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
476         b = (PRUint8 *)roundkeyw;
477         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
478         b = (PRUint8 *)roundkeyw;
479         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
480         b = (PRUint8 *)roundkeyw;
481         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
482         if (Nb <= 4)
483             continue;
484         switch (Nb) {
485             case 8:
486                 b = (PRUint8 *)roundkeyw;
487                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
488                                IMXC2(b[2]) ^ IMXC3(b[3]);
489             case 7:
490                 b = (PRUint8 *)roundkeyw;
491                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
492                                IMXC2(b[2]) ^ IMXC3(b[3]);
493             case 6:
494                 b = (PRUint8 *)roundkeyw;
495                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
496                                IMXC2(b[2]) ^ IMXC3(b[3]);
497             case 5:
498                 b = (PRUint8 *)roundkeyw;
499                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
500                                IMXC2(b[2]) ^ IMXC3(b[3]);
501         }
502     }
503 }
504 
505 /**************************************************************************
506  *
507  * Stuff related to Rijndael encryption/decryption.
508  *
509  *************************************************************************/
510 
/* BYTEnWORD(w) keeps only the byte of w that lies at memory offset n
 * (n = 0..3) within the 32-bit word, zeroing the other three. */
#ifdef IS_LITTLE_ENDIAN
#define BYTE0WORD(w) ((w)&0x000000ff)
#define BYTE1WORD(w) ((w)&0x0000ff00)
#define BYTE2WORD(w) ((w)&0x00ff0000)
#define BYTE3WORD(w) ((w)&0xff000000)
#else
#define BYTE0WORD(w) ((w)&0xff000000)
#define BYTE1WORD(w) ((w)&0x00ff0000)
#define BYTE2WORD(w) ((w)&0x0000ff00)
#define BYTE3WORD(w) ((w)&0x000000ff)
#endif

/* The 16-byte cipher state, viewed either as four 32-bit columns or as
 * bytes; byte 4*c + k is row k of column c. */
typedef union {
    PRUint32 w[4];
    PRUint8 b[16];
} rijndael_state;

#define COLUMN_0(state) state.w[0]
#define COLUMN_1(state) state.w[1]
#define COLUMN_2(state) state.w[2]
#define COLUMN_3(state) state.w[3]

/* Note: implicitly refers to a local variable named "state". */
#define STATE_BYTE(i) state.b[i]
534 
535 // out = a ^ b
536 inline static void
xorBlock(unsigned char * out,const unsigned char * a,const unsigned char * b)537 xorBlock(unsigned char *out, const unsigned char *a, const unsigned char *b)
538 {
539     for (unsigned int j = 0; j < AES_BLOCK_SIZE; ++j) {
540         (out)[j] = (a)[j] ^ (b)[j];
541     }
542 }
543 
/* rijndael_encryptBlock128
 *
 * Encrypt one 16-byte block from input into output with the portable
 * 32-bit T-table implementation, using the round keys in cx->k.expandedKey
 * (cx->Nr rounds, four words per round key, consumed front to back).
 *
 * The state and round keys are handled as 32-bit words.  On NSS_X86_OR_X64
 * the caller's buffers are accessed as words directly even when unaligned
 * (NO_SANITIZE_ALIGNMENT presumably suppresses sanitizer alignment reports
 * for this); on other platforms unaligned input/output is staged through
 * local word-aligned buffers.
 */
static void NO_SANITIZE_ALIGNMENT
rijndael_encryptBlock128(AESContext *cx,
                         unsigned char *output,
                         const unsigned char *input)
{
    unsigned int r;
    PRUint32 *roundkeyw;
    rijndael_state state;
    PRUint32 C0, C1, C2, C3;
#if defined(NSS_X86_OR_X64)
    /* Relies on x86/x64 tolerating unaligned 32-bit loads and stores:
     * alias the caller's buffers directly. */
#define pIn input
#define pOut output
#else
    unsigned char *pIn, *pOut;
    PRUint32 inBuf[4], outBuf[4];

    /* Input not 4-byte aligned: copy it into an aligned buffer first. */
    if ((ptrdiff_t)input & 0x3) {
        memcpy(inBuf, input, sizeof inBuf);
        pIn = (unsigned char *)inBuf;
    } else {
        pIn = (unsigned char *)input;
    }
    /* Output not 4-byte aligned: build the result in outBuf, copied out
     * at the end. */
    if ((ptrdiff_t)output & 0x3) {
        pOut = (unsigned char *)outBuf;
    } else {
        pOut = (unsigned char *)output;
    }
#endif
    roundkeyw = cx->k.expandedKey;
    /* Step 1: Add Round Key 0 to initial state */
    COLUMN_0(state) = *((PRUint32 *)(pIn)) ^ *roundkeyw++;
    COLUMN_1(state) = *((PRUint32 *)(pIn + 4)) ^ *roundkeyw++;
    COLUMN_2(state) = *((PRUint32 *)(pIn + 8)) ^ *roundkeyw++;
    COLUMN_3(state) = *((PRUint32 *)(pIn + 12)) ^ *roundkeyw++;
    /* Step 2: Loop over rounds [1..NR-1] */
    for (r = 1; r < cx->Nr; ++r) {
        /* Do ShiftRow, ByteSub, and MixColumn all at once.
         * New column c takes row k from old column (c + k) mod 4, i.e.
         * state byte 4 * ((c + k) % 4) + k, looked up in table Tk. */
        C0 = T0(STATE_BYTE(0)) ^
             T1(STATE_BYTE(5)) ^
             T2(STATE_BYTE(10)) ^
             T3(STATE_BYTE(15));
        C1 = T0(STATE_BYTE(4)) ^
             T1(STATE_BYTE(9)) ^
             T2(STATE_BYTE(14)) ^
             T3(STATE_BYTE(3));
        C2 = T0(STATE_BYTE(8)) ^
             T1(STATE_BYTE(13)) ^
             T2(STATE_BYTE(2)) ^
             T3(STATE_BYTE(7));
        C3 = T0(STATE_BYTE(12)) ^
             T1(STATE_BYTE(1)) ^
             T2(STATE_BYTE(6)) ^
             T3(STATE_BYTE(11));
        /* Round key addition */
        COLUMN_0(state) = C0 ^ *roundkeyw++;
        COLUMN_1(state) = C1 ^ *roundkeyw++;
        COLUMN_2(state) = C2 ^ *roundkeyw++;
        COLUMN_3(state) = C3 ^ *roundkeyw++;
    }
    /* Step 3: Do the last round */
    /* Final round does not employ MixColumn: from each T-table word keep
     * only the byte position whose coefficient is 01 (i.e. plain S[x]),
     * which yields ByteSub + ShiftRow alone. */
    C0 = ((BYTE0WORD(T2(STATE_BYTE(0)))) |
          (BYTE1WORD(T3(STATE_BYTE(5)))) |
          (BYTE2WORD(T0(STATE_BYTE(10)))) |
          (BYTE3WORD(T1(STATE_BYTE(15))))) ^
         *roundkeyw++;
    C1 = ((BYTE0WORD(T2(STATE_BYTE(4)))) |
          (BYTE1WORD(T3(STATE_BYTE(9)))) |
          (BYTE2WORD(T0(STATE_BYTE(14)))) |
          (BYTE3WORD(T1(STATE_BYTE(3))))) ^
         *roundkeyw++;
    C2 = ((BYTE0WORD(T2(STATE_BYTE(8)))) |
          (BYTE1WORD(T3(STATE_BYTE(13)))) |
          (BYTE2WORD(T0(STATE_BYTE(2)))) |
          (BYTE3WORD(T1(STATE_BYTE(7))))) ^
         *roundkeyw++;
    C3 = ((BYTE0WORD(T2(STATE_BYTE(12)))) |
          (BYTE1WORD(T3(STATE_BYTE(1)))) |
          (BYTE2WORD(T0(STATE_BYTE(6)))) |
          (BYTE3WORD(T1(STATE_BYTE(11))))) ^
         *roundkeyw++;
    *((PRUint32 *)pOut) = C0;
    *((PRUint32 *)(pOut + 4)) = C1;
    *((PRUint32 *)(pOut + 8)) = C2;
    *((PRUint32 *)(pOut + 12)) = C3;
#if defined(NSS_X86_OR_X64)
#undef pIn
#undef pOut
#else
    /* Copy the result out if it was staged in outBuf. */
    if ((ptrdiff_t)output & 0x3) {
        memcpy(output, outBuf, sizeof outBuf);
    }
#endif
}
638 
639 static void NO_SANITIZE_ALIGNMENT
rijndael_decryptBlock128(AESContext * cx,unsigned char * output,const unsigned char * input)640 rijndael_decryptBlock128(AESContext *cx,
641                          unsigned char *output,
642                          const unsigned char *input)
643 {
644     int r;
645     PRUint32 *roundkeyw;
646     rijndael_state state;
647     PRUint32 C0, C1, C2, C3;
648 #if defined(NSS_X86_OR_X64)
649 #define pIn input
650 #define pOut output
651 #else
652     unsigned char *pIn, *pOut;
653     PRUint32 inBuf[4], outBuf[4];
654 
655     if ((ptrdiff_t)input & 0x3) {
656         memcpy(inBuf, input, sizeof inBuf);
657         pIn = (unsigned char *)inBuf;
658     } else {
659         pIn = (unsigned char *)input;
660     }
661     if ((ptrdiff_t)output & 0x3) {
662         pOut = (unsigned char *)outBuf;
663     } else {
664         pOut = (unsigned char *)output;
665     }
666 #endif
667     roundkeyw = cx->k.expandedKey + cx->Nb * cx->Nr + 3;
668     /* reverse the final key addition */
669     COLUMN_3(state) = *((PRUint32 *)(pIn + 12)) ^ *roundkeyw--;
670     COLUMN_2(state) = *((PRUint32 *)(pIn + 8)) ^ *roundkeyw--;
671     COLUMN_1(state) = *((PRUint32 *)(pIn + 4)) ^ *roundkeyw--;
672     COLUMN_0(state) = *((PRUint32 *)(pIn)) ^ *roundkeyw--;
673     /* Loop over rounds in reverse [NR..1] */
674     for (r = cx->Nr; r > 1; --r) {
675         /* Invert the (InvByteSub*InvMixColumn)(InvShiftRow(state)) */
676         C0 = TInv0(STATE_BYTE(0)) ^
677              TInv1(STATE_BYTE(13)) ^
678              TInv2(STATE_BYTE(10)) ^
679              TInv3(STATE_BYTE(7));
680         C1 = TInv0(STATE_BYTE(4)) ^
681              TInv1(STATE_BYTE(1)) ^
682              TInv2(STATE_BYTE(14)) ^
683              TInv3(STATE_BYTE(11));
684         C2 = TInv0(STATE_BYTE(8)) ^
685              TInv1(STATE_BYTE(5)) ^
686              TInv2(STATE_BYTE(2)) ^
687              TInv3(STATE_BYTE(15));
688         C3 = TInv0(STATE_BYTE(12)) ^
689              TInv1(STATE_BYTE(9)) ^
690              TInv2(STATE_BYTE(6)) ^
691              TInv3(STATE_BYTE(3));
692         /* Invert the key addition step */
693         COLUMN_3(state) = C3 ^ *roundkeyw--;
694         COLUMN_2(state) = C2 ^ *roundkeyw--;
695         COLUMN_1(state) = C1 ^ *roundkeyw--;
696         COLUMN_0(state) = C0 ^ *roundkeyw--;
697     }
698     /* inverse sub */
699     pOut[0] = SINV(STATE_BYTE(0));
700     pOut[1] = SINV(STATE_BYTE(13));
701     pOut[2] = SINV(STATE_BYTE(10));
702     pOut[3] = SINV(STATE_BYTE(7));
703     pOut[4] = SINV(STATE_BYTE(4));
704     pOut[5] = SINV(STATE_BYTE(1));
705     pOut[6] = SINV(STATE_BYTE(14));
706     pOut[7] = SINV(STATE_BYTE(11));
707     pOut[8] = SINV(STATE_BYTE(8));
708     pOut[9] = SINV(STATE_BYTE(5));
709     pOut[10] = SINV(STATE_BYTE(2));
710     pOut[11] = SINV(STATE_BYTE(15));
711     pOut[12] = SINV(STATE_BYTE(12));
712     pOut[13] = SINV(STATE_BYTE(9));
713     pOut[14] = SINV(STATE_BYTE(6));
714     pOut[15] = SINV(STATE_BYTE(3));
715     /* final key addition */
716     *((PRUint32 *)(pOut + 12)) ^= *roundkeyw--;
717     *((PRUint32 *)(pOut + 8)) ^= *roundkeyw--;
718     *((PRUint32 *)(pOut + 4)) ^= *roundkeyw--;
719     *((PRUint32 *)pOut) ^= *roundkeyw--;
720 #if defined(NSS_X86_OR_X64)
721 #undef pIn
722 #undef pOut
723 #else
724     if ((ptrdiff_t)output & 0x3) {
725         memcpy(output, outBuf, sizeof outBuf);
726     }
727 #endif
728 }
729 
730 /**************************************************************************
731  *
732  *  Rijndael modes of operation (ECB and CBC)
733  *
734  *************************************************************************/
735 
736 static SECStatus
rijndael_encryptECB(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)737 rijndael_encryptECB(AESContext *cx, unsigned char *output,
738                     unsigned int *outputLen, unsigned int maxOutputLen,
739                     const unsigned char *input, unsigned int inputLen)
740 {
741     PRBool aesni = aesni_support();
742     while (inputLen > 0) {
743         if (aesni) {
744             rijndael_native_encryptBlock(cx, output, input);
745         } else {
746             rijndael_encryptBlock128(cx, output, input);
747         }
748         output += AES_BLOCK_SIZE;
749         input += AES_BLOCK_SIZE;
750         inputLen -= AES_BLOCK_SIZE;
751     }
752     return SECSuccess;
753 }
754 
755 static SECStatus
rijndael_encryptCBC(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)756 rijndael_encryptCBC(AESContext *cx, unsigned char *output,
757                     unsigned int *outputLen, unsigned int maxOutputLen,
758                     const unsigned char *input, unsigned int inputLen)
759 {
760     unsigned char *lastblock = cx->iv;
761     unsigned char inblock[AES_BLOCK_SIZE * 8];
762     PRBool aesni = aesni_support();
763 
764     if (!inputLen)
765         return SECSuccess;
766     while (inputLen > 0) {
767         if (aesni) {
768             /* XOR with the last block (IV if first block) */
769             native_xorBlock(inblock, input, lastblock);
770             /* encrypt */
771             rijndael_native_encryptBlock(cx, output, inblock);
772         } else {
773             xorBlock(inblock, input, lastblock);
774             rijndael_encryptBlock128(cx, output, inblock);
775         }
776 
777         /* move to the next block */
778         lastblock = output;
779         output += AES_BLOCK_SIZE;
780         input += AES_BLOCK_SIZE;
781         inputLen -= AES_BLOCK_SIZE;
782     }
783     memcpy(cx->iv, lastblock, AES_BLOCK_SIZE);
784     return SECSuccess;
785 }
786 
787 static SECStatus
rijndael_decryptECB(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)788 rijndael_decryptECB(AESContext *cx, unsigned char *output,
789                     unsigned int *outputLen, unsigned int maxOutputLen,
790                     const unsigned char *input, unsigned int inputLen)
791 {
792     PRBool aesni = aesni_support();
793     while (inputLen > 0) {
794         if (aesni) {
795             rijndael_native_decryptBlock(cx, output, input);
796         } else {
797             rijndael_decryptBlock128(cx, output, input);
798         }
799         output += AES_BLOCK_SIZE;
800         input += AES_BLOCK_SIZE;
801         inputLen -= AES_BLOCK_SIZE;
802     }
803     return SECSuccess;
804 }
805 
/* rijndael_decryptCBC
 *
 * CBC-mode decryption: P[k] = Decrypt(C[k]) ^ C[k-1], with C[-1] = cx->iv.
 * On return cx->iv holds the last ciphertext block of this call.
 * inputLen is expected to be a multiple of AES_BLOCK_SIZE (not checked
 * here); outputLen and maxOutputLen are unused.  Always returns SECSuccess.
 *
 * Blocks are processed from last to first.  This ordering is what allows
 * output to overlap input when it starts at or after input (including fully
 * in place): when block k is written, the ciphertext block k-1 it still
 * needs has not been overwritten yet.
 */
static SECStatus
rijndael_decryptCBC(AESContext *cx, unsigned char *output,
                    unsigned int *outputLen, unsigned int maxOutputLen,
                    const unsigned char *input, unsigned int inputLen)
{
    const unsigned char *in;
    unsigned char *out;
    unsigned char newIV[AES_BLOCK_SIZE];
    PRBool aesni = aesni_support();

    if (!inputLen)
        return SECSuccess;
    /* Permitted layouts: output at or after input, or output entirely
     * before input. */
    PORT_Assert(output - input >= 0 || input - output >= (int)inputLen);
    in = input + (inputLen - AES_BLOCK_SIZE);
    /* Save the last ciphertext block before an in-place run overwrites it. */
    memcpy(newIV, in, AES_BLOCK_SIZE);
    out = output + (inputLen - AES_BLOCK_SIZE);
    /* Every block except the first chains with the ciphertext block before it. */
    while (inputLen > AES_BLOCK_SIZE) {
        if (aesni) {
            // Use hardware acceleration for normal AES parameters.
            rijndael_native_decryptBlock(cx, out, in);
            native_xorBlock(out, out, &in[-AES_BLOCK_SIZE]);
        } else {
            rijndael_decryptBlock128(cx, out, in);
            xorBlock(out, out, &in[-AES_BLOCK_SIZE]);
        }
        out -= AES_BLOCK_SIZE;
        in -= AES_BLOCK_SIZE;
        inputLen -= AES_BLOCK_SIZE;
    }
    /* The first block chains with the stored IV. */
    if (in == input) {
        if (aesni) {
            rijndael_native_decryptBlock(cx, out, in);
            native_xorBlock(out, out, cx->iv);
        } else {
            rijndael_decryptBlock128(cx, out, in);
            xorBlock(out, out, cx->iv);
        }
    }
    memcpy(cx->iv, newIV, AES_BLOCK_SIZE);
    return SECSuccess;
}
847 
848 /************************************************************************
849  *
850  * BLAPI Interface functions
851  *
852  * The following functions implement the encryption routines defined in
853  * BLAPI for the AES cipher, Rijndael.
854  *
855  ***********************************************************************/
856 
857 AESContext *
AES_AllocateContext(void)858 AES_AllocateContext(void)
859 {
860     return PORT_ZNewAligned(AESContext, 16, mem);
861 }
862 
863 /*
864 ** Initialize a new AES context suitable for AES encryption/decryption in
865 ** the ECB or CBC mode.
866 **  "mode" the mode of operation, which must be NSS_AES or NSS_AES_CBC
867 */
868 static SECStatus
aes_InitContext(AESContext * cx,const unsigned char * key,unsigned int keysize,const unsigned char * iv,int mode,unsigned int encrypt)869 aes_InitContext(AESContext *cx, const unsigned char *key, unsigned int keysize,
870                 const unsigned char *iv, int mode, unsigned int encrypt)
871 {
872     unsigned int Nk;
873     PRBool use_hw_aes;
874     /* According to AES, block lengths are 128 and key lengths are 128, 192, or
875      * 256 bits. We support other key sizes as well [128, 256] as long as the
876      * length in bytes is divisible by 4.
877      */
878 
879     if (key == NULL ||
880         keysize < AES_BLOCK_SIZE ||
881         keysize > 32 ||
882         keysize % 4 != 0) {
883         PORT_SetError(SEC_ERROR_INVALID_ARGS);
884         return SECFailure;
885     }
886     if (mode != NSS_AES && mode != NSS_AES_CBC) {
887         PORT_SetError(SEC_ERROR_INVALID_ARGS);
888         return SECFailure;
889     }
890     if (mode == NSS_AES_CBC && iv == NULL) {
891         PORT_SetError(SEC_ERROR_INVALID_ARGS);
892         return SECFailure;
893     }
894     if (!cx) {
895         PORT_SetError(SEC_ERROR_INVALID_ARGS);
896         return SECFailure;
897     }
898 #if defined(NSS_X86_OR_X64) || defined(USE_HW_AES)
899     use_hw_aes = (aesni_support() || arm_aes_support()) && (keysize % 8) == 0;
900 #else
901     use_hw_aes = PR_FALSE;
902 #endif
903     /* Nb = (block size in bits) / 32 */
904     cx->Nb = AES_BLOCK_SIZE / 4;
905     /* Nk = (key size in bits) / 32 */
906     Nk = keysize / 4;
907     /* Obtain number of rounds from "table" */
908     cx->Nr = RIJNDAEL_NUM_ROUNDS(Nk, cx->Nb);
909     /* copy in the iv, if neccessary */
910     if (mode == NSS_AES_CBC) {
911         memcpy(cx->iv, iv, AES_BLOCK_SIZE);
912 #ifdef USE_HW_AES
913         if (use_hw_aes) {
914             cx->worker = (freeblCipherFunc)
915                 native_aes_cbc_worker(encrypt, keysize);
916         } else
917 #endif
918         {
919             cx->worker = (freeblCipherFunc)(encrypt
920                                                 ? &rijndael_encryptCBC
921                                                 : &rijndael_decryptCBC);
922         }
923     } else {
924 #ifdef USE_HW_AES
925         if (use_hw_aes) {
926             cx->worker = (freeblCipherFunc)
927                 native_aes_ecb_worker(encrypt, keysize);
928         } else
929 #endif
930         {
931             cx->worker = (freeblCipherFunc)(encrypt
932                                                 ? &rijndael_encryptECB
933                                                 : &rijndael_decryptECB);
934         }
935     }
936     PORT_Assert((cx->Nb * (cx->Nr + 1)) <= RIJNDAEL_MAX_EXP_KEY_SIZE);
937     if ((cx->Nb * (cx->Nr + 1)) > RIJNDAEL_MAX_EXP_KEY_SIZE) {
938         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
939         return SECFailure;
940     }
941 #ifdef USE_HW_AES
942     if (use_hw_aes) {
943         native_aes_init(encrypt, keysize);
944     } else
945 #endif
946     {
947         /* Generate expanded key */
948         if (encrypt) {
949             if (use_hw_aes && (cx->mode == NSS_AES_GCM || cx->mode == NSS_AES ||
950                                cx->mode == NSS_AES_CTR)) {
951                 PORT_Assert(keysize == 16 || keysize == 24 || keysize == 32);
952                 /* Prepare hardware key for normal AES parameters. */
953                 rijndael_native_key_expansion(cx, key, Nk);
954             } else {
955                 rijndael_key_expansion(cx, key, Nk);
956             }
957         } else {
958             rijndael_invkey_expansion(cx, key, Nk);
959         }
960     }
961     cx->worker_cx = cx;
962     cx->destroy = NULL;
963     cx->isBlock = PR_TRUE;
964     return SECSuccess;
965 }
966 
967 SECStatus
AES_InitContext(AESContext * cx,const unsigned char * key,unsigned int keysize,const unsigned char * iv,int mode,unsigned int encrypt,unsigned int blocksize)968 AES_InitContext(AESContext *cx, const unsigned char *key, unsigned int keysize,
969                 const unsigned char *iv, int mode, unsigned int encrypt,
970                 unsigned int blocksize)
971 {
972     int basemode = mode;
973     PRBool baseencrypt = encrypt;
974     SECStatus rv;
975 
976     if (blocksize != AES_BLOCK_SIZE) {
977         PORT_SetError(SEC_ERROR_INVALID_ARGS);
978         return SECFailure;
979     }
980 
981     switch (mode) {
982         case NSS_AES_CTS:
983             basemode = NSS_AES_CBC;
984             break;
985         case NSS_AES_GCM:
986         case NSS_AES_CTR:
987             basemode = NSS_AES;
988             baseencrypt = PR_TRUE;
989             break;
990     }
991     /* Make sure enough is initialized so we can safely call Destroy. */
992     cx->worker_cx = NULL;
993     cx->destroy = NULL;
994     cx->mode = mode;
995     rv = aes_InitContext(cx, key, keysize, iv, basemode, baseencrypt);
996     if (rv != SECSuccess) {
997         AES_DestroyContext(cx, PR_FALSE);
998         return rv;
999     }
1000 
1001     /* finally, set up any mode specific contexts */
1002     cx->worker_aead = 0;
1003     switch (mode) {
1004         case NSS_AES_CTS:
1005             cx->worker_cx = CTS_CreateContext(cx, cx->worker, iv);
1006             cx->worker = (freeblCipherFunc)(encrypt ? CTS_EncryptUpdate : CTS_DecryptUpdate);
1007             cx->destroy = (freeblDestroyFunc)CTS_DestroyContext;
1008             cx->isBlock = PR_FALSE;
1009             break;
1010         case NSS_AES_GCM:
1011 #if defined(INTEL_GCM) && defined(USE_HW_AES)
1012             if (aesni_support() && (keysize % 8) == 0 && avx_support() &&
1013                 clmul_support()) {
1014                 cx->worker_cx = intel_AES_GCM_CreateContext(cx, cx->worker, iv);
1015                 cx->worker = (freeblCipherFunc)(encrypt ? intel_AES_GCM_EncryptUpdate
1016                                                         : intel_AES_GCM_DecryptUpdate);
1017                 cx->worker_aead = (freeblAeadFunc)(encrypt ? intel_AES_GCM_EncryptAEAD
1018                                                            : intel_AES_GCM_DecryptAEAD);
1019                 cx->destroy = (freeblDestroyFunc)intel_AES_GCM_DestroyContext;
1020                 cx->isBlock = PR_FALSE;
1021             } else
1022 #endif
1023             {
1024                 cx->worker_cx = GCM_CreateContext(cx, cx->worker, iv);
1025                 cx->worker = (freeblCipherFunc)(encrypt ? GCM_EncryptUpdate
1026                                                         : GCM_DecryptUpdate);
1027                 cx->worker_aead = (freeblAeadFunc)(encrypt ? GCM_EncryptAEAD
1028                                                            : GCM_DecryptAEAD);
1029 
1030                 cx->destroy = (freeblDestroyFunc)GCM_DestroyContext;
1031                 cx->isBlock = PR_FALSE;
1032             }
1033             break;
1034         case NSS_AES_CTR:
1035             cx->worker_cx = CTR_CreateContext(cx, cx->worker, iv);
1036 #if defined(USE_HW_AES) && defined(_MSC_VER) && defined(NSS_X86_OR_X64)
1037             if (aesni_support() && (keysize % 8) == 0) {
1038                 cx->worker = (freeblCipherFunc)CTR_Update_HW_AES;
1039             } else
1040 #endif
1041             {
1042                 cx->worker = (freeblCipherFunc)CTR_Update;
1043             }
1044             cx->destroy = (freeblDestroyFunc)CTR_DestroyContext;
1045             cx->isBlock = PR_FALSE;
1046             break;
1047         default:
1048             /* everything has already been set up by aes_InitContext, just
1049              * return */
1050             return SECSuccess;
1051     }
1052     /* check to see if we succeeded in getting the worker context */
1053     if (cx->worker_cx == NULL) {
1054         /* no, just destroy the existing context */
1055         cx->destroy = NULL; /* paranoia, though you can see a dozen lines */
1056                             /* below that this isn't necessary */
1057         AES_DestroyContext(cx, PR_FALSE);
1058         return SECFailure;
1059     }
1060     return SECSuccess;
1061 }
1062 
1063 /* AES_CreateContext
1064  *
1065  * create a new context for Rijndael operations
1066  */
1067 AESContext *
AES_CreateContext(const unsigned char * key,const unsigned char * iv,int mode,int encrypt,unsigned int keysize,unsigned int blocksize)1068 AES_CreateContext(const unsigned char *key, const unsigned char *iv,
1069                   int mode, int encrypt,
1070                   unsigned int keysize, unsigned int blocksize)
1071 {
1072     AESContext *cx = AES_AllocateContext();
1073     if (cx) {
1074         SECStatus rv = AES_InitContext(cx, key, keysize, iv, mode, encrypt,
1075                                        blocksize);
1076         if (rv != SECSuccess) {
1077             AES_DestroyContext(cx, PR_TRUE);
1078             cx = NULL;
1079         }
1080     }
1081     return cx;
1082 }
1083 
1084 /*
1085  * AES_DestroyContext
1086  *
1087  * Zero an AES cipher context.  If freeit is true, also free the pointer
1088  * to the context.
1089  */
1090 void
AES_DestroyContext(AESContext * cx,PRBool freeit)1091 AES_DestroyContext(AESContext *cx, PRBool freeit)
1092 {
1093     void *mem = cx->mem;
1094     if (cx->worker_cx && cx->destroy) {
1095         (*cx->destroy)(cx->worker_cx, PR_TRUE);
1096         cx->worker_cx = NULL;
1097         cx->destroy = NULL;
1098     }
1099     PORT_Memset(cx, 0, sizeof(AESContext));
1100     if (freeit) {
1101         PORT_Free(mem);
1102     } else {
1103         /* if we are not freeing the context, restore mem, We may get called
1104          * again to actually free the context */
1105         cx->mem = mem;
1106     }
1107 }
1108 
1109 /*
1110  * AES_Encrypt
1111  *
1112  * Encrypt an arbitrary-length buffer.  The output buffer must already be
1113  * allocated to at least inputLen.
1114  */
1115 SECStatus
AES_Encrypt(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)1116 AES_Encrypt(AESContext *cx, unsigned char *output,
1117             unsigned int *outputLen, unsigned int maxOutputLen,
1118             const unsigned char *input, unsigned int inputLen)
1119 {
1120     /* Check args */
1121     if (cx == NULL || output == NULL || (input == NULL && inputLen != 0)) {
1122         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1123         return SECFailure;
1124     }
1125     if (cx->isBlock && (inputLen % AES_BLOCK_SIZE != 0)) {
1126         PORT_SetError(SEC_ERROR_INPUT_LEN);
1127         return SECFailure;
1128     }
1129     if (maxOutputLen < inputLen) {
1130         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1131         return SECFailure;
1132     }
1133     *outputLen = inputLen;
1134 #if UINT_MAX > MP_32BIT_MAX
1135     /*
1136      * we can guarentee that GSM won't overlfow if we limit the input to
1137      * 2^36 bytes. For simplicity, we are limiting it to 2^32 for now.
1138      *
1139      * We do it here to cover both hardware and software GCM operations.
1140      */
1141     {
1142         PR_STATIC_ASSERT(sizeof(unsigned int) > 4);
1143     }
1144     if ((cx->mode == NSS_AES_GCM) && (inputLen > MP_32BIT_MAX)) {
1145         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1146         return SECFailure;
1147     }
1148 #else
1149     /* if we can't pass in a 32_bit number, then no such check needed */
1150     {
1151         PR_STATIC_ASSERT(sizeof(unsigned int) <= 4);
1152     }
1153 #endif
1154 
1155     return (*cx->worker)(cx->worker_cx, output, outputLen, maxOutputLen,
1156                          input, inputLen, AES_BLOCK_SIZE);
1157 }
1158 
1159 /*
1160  * AES_Decrypt
1161  *
1162  * Decrypt and arbitrary-length buffer.  The output buffer must already be
1163  * allocated to at least inputLen.
1164  */
1165 SECStatus
AES_Decrypt(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)1166 AES_Decrypt(AESContext *cx, unsigned char *output,
1167             unsigned int *outputLen, unsigned int maxOutputLen,
1168             const unsigned char *input, unsigned int inputLen)
1169 {
1170     /* Check args */
1171     if (cx == NULL || output == NULL || (input == NULL && inputLen != 0)) {
1172         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1173         return SECFailure;
1174     }
1175     if (cx->isBlock && (inputLen % AES_BLOCK_SIZE != 0)) {
1176         PORT_SetError(SEC_ERROR_INPUT_LEN);
1177         return SECFailure;
1178     }
1179     if ((cx->mode != NSS_AES_GCM) && (maxOutputLen < inputLen)) {
1180         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1181         return SECFailure;
1182     }
1183     *outputLen = inputLen;
1184     return (*cx->worker)(cx->worker_cx, output, outputLen, maxOutputLen,
1185                          input, inputLen, AES_BLOCK_SIZE);
1186 }
1187 
1188 /*
1189  * AES_Encrypt_AEAD
1190  *
1191  * Encrypt using GCM or CCM. include the nonce, extra data, and the tag
1192  */
1193 SECStatus
AES_AEAD(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen,void * params,unsigned int paramsLen,const unsigned char * aad,unsigned int aadLen)1194 AES_AEAD(AESContext *cx, unsigned char *output,
1195          unsigned int *outputLen, unsigned int maxOutputLen,
1196          const unsigned char *input, unsigned int inputLen,
1197          void *params, unsigned int paramsLen,
1198          const unsigned char *aad, unsigned int aadLen)
1199 {
1200     /* Check args */
1201     if (cx == NULL || output == NULL || (input == NULL && inputLen != 0) || (aad == NULL && aadLen != 0) || params == NULL) {
1202         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1203         return SECFailure;
1204     }
1205     if (cx->worker_aead == NULL) {
1206         PORT_SetError(SEC_ERROR_NOT_INITIALIZED);
1207         return SECFailure;
1208     }
1209     if (maxOutputLen < inputLen) {
1210         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1211         return SECFailure;
1212     }
1213     *outputLen = inputLen;
1214 #if UINT_MAX > MP_32BIT_MAX
1215     /*
1216      * we can guarentee that GSM won't overlfow if we limit the input to
1217      * 2^36 bytes. For simplicity, we are limiting it to 2^32 for now.
1218      *
1219      * We do it here to cover both hardware and software GCM operations.
1220      */
1221     {
1222         PR_STATIC_ASSERT(sizeof(unsigned int) > 4);
1223     }
1224     if (inputLen > MP_32BIT_MAX) {
1225         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1226         return SECFailure;
1227     }
1228 #else
1229     /* if we can't pass in a 32_bit number, then no such check needed */
1230     {
1231         PR_STATIC_ASSERT(sizeof(unsigned int) <= 4);
1232     }
1233 #endif
1234 
1235     return (*cx->worker_aead)(cx->worker_cx, output, outputLen, maxOutputLen,
1236                               input, inputLen, params, paramsLen, aad, aadLen,
1237                               AES_BLOCK_SIZE);
1238 }
1239