1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
5 #ifdef FREEBL_NO_DEPEND
6 #include "stubs.h"
7 #endif
8 
9 #include "prinit.h"
10 #include "prenv.h"
11 #include "prerr.h"
12 #include "secerr.h"
13 
14 #include "prtypes.h"
15 #include "blapi.h"
16 #include "rijndael.h"
17 
18 #include "cts.h"
19 #include "ctr.h"
20 #include "gcm.h"
21 #include "mpi.h"
22 
23 #if !defined(IS_LITTLE_ENDIAN) && !defined(NSS_X86_OR_X64)
// not tested yet on big-endian ARM platforms
25 #undef USE_HW_AES
26 #endif
27 
28 #ifdef __powerpc64__
29 #include "ppc-crypto.h"
30 #endif
31 
32 #ifdef USE_HW_AES
33 #ifdef NSS_X86_OR_X64
34 #include "intel-aes.h"
35 #else
36 #include "aes-armv8.h"
37 #endif
38 #endif /* USE_HW_AES */
39 #ifdef INTEL_GCM
40 #include "intel-gcm.h"
41 #endif /* INTEL_GCM */
42 #if defined(USE_PPC_CRYPTO) && defined(PPC_GCM)
43 #include "ppc-gcm.h"
44 #endif
45 
46 /* Forward declarations */
47 void rijndael_native_key_expansion(AESContext *cx, const unsigned char *key,
48                                    unsigned int Nk);
49 void rijndael_native_encryptBlock(AESContext *cx,
50                                   unsigned char *output,
51                                   const unsigned char *input);
52 void rijndael_native_decryptBlock(AESContext *cx,
53                                   unsigned char *output,
54                                   const unsigned char *input);
55 void native_xorBlock(unsigned char *out,
56                      const unsigned char *a,
57                      const unsigned char *b);
58 
59 /* Stub definitions for the above rijndael_native_* functions, which
60  * shouldn't be used unless NSS_X86_OR_X64 is defined */
61 #ifndef NSS_X86_OR_X64
62 void
rijndael_native_key_expansion(AESContext * cx,const unsigned char * key,unsigned int Nk)63 rijndael_native_key_expansion(AESContext *cx, const unsigned char *key,
64                               unsigned int Nk)
65 {
66     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
67     PORT_Assert(0);
68 }
69 
70 void
rijndael_native_encryptBlock(AESContext * cx,unsigned char * output,const unsigned char * input)71 rijndael_native_encryptBlock(AESContext *cx,
72                              unsigned char *output,
73                              const unsigned char *input)
74 {
75     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
76     PORT_Assert(0);
77 }
78 
79 void
rijndael_native_decryptBlock(AESContext * cx,unsigned char * output,const unsigned char * input)80 rijndael_native_decryptBlock(AESContext *cx,
81                              unsigned char *output,
82                              const unsigned char *input)
83 {
84     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
85     PORT_Assert(0);
86 }
87 
88 void
native_xorBlock(unsigned char * out,const unsigned char * a,const unsigned char * b)89 native_xorBlock(unsigned char *out, const unsigned char *a,
90                 const unsigned char *b)
91 {
92     PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
93     PORT_Assert(0);
94 }
95 #endif /* NSS_X86_OR_X64 */
96 
97 /*
98  * There are currently three ways to build this code, varying in performance
99  * and code size.
100  *
101  * RIJNDAEL_INCLUDE_TABLES         Include all tables from rijndael32.tab
102  * RIJNDAEL_GENERATE_VALUES        Do not store tables, generate the table
103  *                                 values "on-the-fly", using gfm
104  * RIJNDAEL_GENERATE_VALUES_MACRO  Same as above, but use macros
105  *
106  * The default is RIJNDAEL_INCLUDE_TABLES.
107  */
108 
109 /*
110  * When building RIJNDAEL_INCLUDE_TABLES, includes S**-1, Rcon, T[0..4],
111  *                                                 T**-1[0..4], IMXC[0..4]
112  * When building anything else, includes S, S**-1, Rcon
113  */
114 #include "rijndael32.tab"
115 
116 #if defined(RIJNDAEL_INCLUDE_TABLES)
/*
 * RIJNDAEL_INCLUDE_TABLES
 *
 * In this configuration every table macro is a plain read of a stored array.
 */

/* Forward round tables. */
#define T0(i) _T0[i]
#define T1(i) _T1[i]
#define T2(i) _T2[i]
#define T3(i) _T3[i]

/* Inverse round tables. */
#define TInv0(i) _TInv0[i]
#define TInv1(i) _TInv1[i]
#define TInv2(i) _TInv2[i]
#define TInv3(i) _TInv3[i]

/* Inverse mix column tables. */
#define IMXC0(b) _IMXC0[b]
#define IMXC1(b) _IMXC1[b]
#define IMXC2(b) _IMXC2[b]
#define IMXC3(b) _IMXC3[b]

/* No separate S table is referenced here: the S-box byte is taken from the
 * low 8 bits of a T-table entry; which table is used depends on endianness. */
#ifdef IS_LITTLE_ENDIAN
#define SBOX(b) ((PRUint8)_T3[b])
#else
#define SBOX(b) ((PRUint8)_T1[b])
#endif

/* The inverse S-box is read from its own stored table. */
#define SINV(b) (_SInv[b])
139 
140 #else /* not RIJNDAEL_INCLUDE_TABLES */
141 
142 /*
143  * Code for generating T-table values.
144  */
145 
/*
 * WORD4(b0, b1, b2, b3)
 *
 * Pack four byte values into one 32-bit word so that b0 ends up in the
 * lowest-addressed byte of the word in memory and b3 in the highest.
 * Every argument is parenthesized before the cast so that an expression
 * argument is converted as a whole rather than only its first operand.
 */
#ifdef IS_LITTLE_ENDIAN
#define WORD4(b0, b1, b2, b3)     \
    ((((PRUint32)(b3)) << 24) |   \
     (((PRUint32)(b2)) << 16) |   \
     (((PRUint32)(b1)) << 8) |    \
     ((PRUint32)(b0)))
#else
#define WORD4(b0, b1, b2, b3)     \
    ((((PRUint32)(b0)) << 24) |   \
     (((PRUint32)(b1)) << 16) |   \
     (((PRUint32)(b2)) << 8) |    \
     ((PRUint32)(b3)))
#endif
159 
/*
 * Define the S and S**-1 tables (both have been stored)
 */
#define SBOX(b) (_S[b])
#define SINV(b) (_SInv[b])

/*
 * The function xtime, used for Galois field multiplication.
 *
 * Shifts a left by one bit and, when bit 7 of a was set, XORs in 0x1b.
 * The result is NOT masked to 8 bits here; users either mask with 0xff or
 * store it into an 8-bit variable.  The argument is parenthesized at every
 * use so that expression arguments (e.g. x | y) are evaluated as a unit;
 * note that it is still evaluated more than once.
 */
#define XTIME(a) \
    (((a) & 0x80) ? (((a) << 1) ^ 0x1b) : ((a) << 1))
171 
172 /* Choose GFM method (macros or function) */
173 #if defined(RIJNDAEL_GENERATE_VALUES_MACRO)
174 
/*
 * Galois field GF(2**8) multipliers, in macro form.
 *
 * GFM02 is one application of XTIME masked to 8 bits; the larger
 * multipliers are composed from GFM01/GFM02/GFM04/GFM08 with XOR.
 */
#define GFM01(a) (a)                              /* a * 01 = a, the identity */
#define GFM02(a) (XTIME(a) & 0xff)                /* a * 02 = xtime(a) */
#define GFM04(a) (GFM02(GFM02(a)))                /* a * 04 = xtime**2(a) */
#define GFM08(a) (GFM02(GFM04(a)))                /* a * 08 = xtime**3(a) */
#define GFM03(a) (GFM01(a) ^ GFM02(a))            /* a * 03 = a * (01 + 02) */
#define GFM09(a) (GFM01(a) ^ GFM08(a))            /* a * 09 = a * (01 + 08) */
#define GFM0B(a) (GFM01(a) ^ GFM02(a) ^ GFM08(a)) /* a * 0B = a * (01 + 02 + 08) */
#define GFM0D(a) (GFM01(a) ^ GFM04(a) ^ GFM08(a)) /* a * 0D = a * (01 + 04 + 08) */
#define GFM0E(a) (GFM02(a) ^ GFM04(a) ^ GFM08(a)) /* a * 0E = a * (02 + 04 + 08) */
196 
197 #else /* RIJNDAEL_GENERATE_VALUES */
198 
199 /* GF_MULTIPLY
200  *
201  * multiply two bytes represented in GF(2**8), mod (x**4 + 1)
202  */
203 PRUint8
gfm(PRUint8 a,PRUint8 b)204 gfm(PRUint8 a, PRUint8 b)
205 {
206     PRUint8 res = 0;
207     while (b > 0) {
208         res = (b & 0x01) ? res ^ a : res;
209         a = XTIME(a);
210         b >>= 1;
211     }
212     return res;
213 }
214 
/*
 * Galois field GF(2**8) multipliers, function form: everything above
 * a * 02 is delegated to gfm().
 */
#define GFM01(a) (a)                /* a * 01 = a, the identity */
#define GFM02(a) (XTIME(a) & 0xff)  /* a * 02 = xtime(a) */
#define GFM03(a) (gfm(a, 0x03))     /* a * 03 */
#define GFM09(a) (gfm(a, 0x09))     /* a * 09 */
#define GFM0B(a) (gfm(a, 0x0B))     /* a * 0B */
#define GFM0D(a) (gfm(a, 0x0D))     /* a * 0D */
#define GFM0E(a) (gfm(a, 0x0E))     /* a * 0E */
229 
230 #endif /* choosing GFM function */
231 
/*
 * G_COLUMN(m0, m1, m2, m3, s)
 *
 * Build one table word from the byte value s: m0..m3 each name a GFMxx
 * multiplier macro, which is applied to s, and WORD4 packs the four
 * products in that order.
 */
#define G_COLUMN(m0, m1, m2, m3, s) \
    (WORD4(m0(s), m1(s), m2(s), m3(s)))

/*
 * The T-tables: multiples of SBOX(i)
 */
#define G_T0(i) G_COLUMN(GFM02, GFM01, GFM01, GFM03, SBOX(i))
#define G_T1(i) G_COLUMN(GFM03, GFM02, GFM01, GFM01, SBOX(i))
#define G_T2(i) G_COLUMN(GFM01, GFM03, GFM02, GFM01, SBOX(i))
#define G_T3(i) G_COLUMN(GFM01, GFM01, GFM03, GFM02, SBOX(i))

/*
 * The inverse T-tables: multiples of SINV(i)
 */
#define G_TInv0(i) G_COLUMN(GFM0E, GFM09, GFM0D, GFM0B, SINV(i))
#define G_TInv1(i) G_COLUMN(GFM0B, GFM0E, GFM09, GFM0D, SINV(i))
#define G_TInv2(i) G_COLUMN(GFM0D, GFM0B, GFM0E, GFM09, SINV(i))
#define G_TInv3(i) G_COLUMN(GFM09, GFM0D, GFM0B, GFM0E, SINV(i))

/*
 * The inverse mix column tables: the same multipliers applied to i itself
 */
#define G_IMXC0(i) G_COLUMN(GFM0E, GFM09, GFM0D, GFM0B, i)
#define G_IMXC1(i) G_COLUMN(GFM0B, GFM0E, GFM09, GFM0D, i)
#define G_IMXC2(i) G_COLUMN(GFM0D, GFM0B, GFM0E, GFM09, i)
#define G_IMXC3(i) G_COLUMN(GFM09, GFM0D, GFM0B, GFM0E, i)
267 
268 /* Now choose the T-table indexing method */
269 #if defined(RIJNDAEL_GENERATE_VALUES)
270 /* generate values for the tables with a function*/
271 static PRUint32
gen_TInvXi(PRUint8 tx,PRUint8 i)272 gen_TInvXi(PRUint8 tx, PRUint8 i)
273 {
274     PRUint8 si01, si02, si03, si04, si08, si09, si0B, si0D, si0E;
275     si01 = SINV(i);
276     si02 = XTIME(si01);
277     si04 = XTIME(si02);
278     si08 = XTIME(si04);
279     si03 = si02 ^ si01;
280     si09 = si08 ^ si01;
281     si0B = si08 ^ si03;
282     si0D = si09 ^ si04;
283     si0E = si08 ^ si04 ^ si02;
284     switch (tx) {
285         case 0:
286             return WORD4(si0E, si09, si0D, si0B);
287         case 1:
288             return WORD4(si0B, si0E, si09, si0D);
289         case 2:
290             return WORD4(si0D, si0B, si0E, si09);
291         case 3:
292             return WORD4(si09, si0D, si0B, si0E);
293     }
294     return -1;
295 }
/* Forward tables and inverse mix column tables come from the G_* macros. */
#define T0(i) G_T0(i)
#define T1(i) G_T1(i)
#define T2(i) G_T2(i)
#define T3(i) G_T3(i)
#define IMXC0(b) G_IMXC0(b)
#define IMXC1(b) G_IMXC1(b)
#define IMXC2(b) G_IMXC2(b)
#define IMXC3(b) G_IMXC3(b)

/* Inverse tables are computed by gen_TInvXi(); the first argument selects
 * the table. */
#define TInv0(i) gen_TInvXi(0, i)
#define TInv1(i) gen_TInvXi(1, i)
#define TInv2(i) gen_TInvXi(2, i)
#define TInv3(i) gen_TInvXi(3, i)
308 #else /* RIJNDAEL_GENERATE_VALUES_MACRO */
/* Generate values for the tables with macros: every table macro maps
 * straight onto its G_* generator. */

/* Forward tables. */
#define T0(i) G_T0(i)
#define T1(i) G_T1(i)
#define T2(i) G_T2(i)
#define T3(i) G_T3(i)

/* Inverse tables. */
#define TInv0(i) G_TInv0(i)
#define TInv1(i) G_TInv1(i)
#define TInv2(i) G_TInv2(i)
#define TInv3(i) G_TInv3(i)

/* Inverse mix column tables. */
#define IMXC0(b) G_IMXC0(b)
#define IMXC1(b) G_IMXC1(b)
#define IMXC2(b) G_IMXC2(b)
#define IMXC3(b) G_IMXC3(b)
322 #endif /* choose T-table indexing method */
323 
324 #endif /* not RIJNDAEL_INCLUDE_TABLES */
325 
326 /**************************************************************************
327  *
328  * Stuff related to the Rijndael key schedule
329  *
330  *************************************************************************/
331 
/*
 * SUBBYTE(w): apply SBOX to each of the four bytes of the 32-bit word w,
 * keeping every byte in its position.
 *
 * ROTBYTE(b): rotate the 32-bit word b by 8 bits; the direction depends on
 * endianness.  b is expected to be an unsigned 32-bit value.
 *
 * The parameters are parenthesized at every use so that expression
 * arguments are evaluated as a unit.  Both macros still evaluate their
 * argument more than once, so it must not have side effects.
 */
#define SUBBYTE(w)                                  \
    ((((PRUint32)SBOX(((w) >> 24) & 0xff)) << 24) | \
     (((PRUint32)SBOX(((w) >> 16) & 0xff)) << 16) | \
     (((PRUint32)SBOX(((w) >> 8) & 0xff)) << 8) |   \
     (((PRUint32)SBOX((w) & 0xff))))

#ifdef IS_LITTLE_ENDIAN
#define ROTBYTE(b) \
    (((b) >> 8) | ((b) << 24))
#else
#define ROTBYTE(b) \
    (((b) << 8) | ((b) >> 24))
#endif
345 
346 /* rijndael_key_expansion7
347  *
348  * Generate the expanded key from the key input by the user.
349  * XXX
350  * Nk == 7 (224 key bits) is a weird case.  Since Nk > 6, an added SubByte
351  * transformation is done periodically.  The period is every 4 bytes, and
352  * since 7%4 != 0 this happens at different times for each key word (unlike
353  * Nk == 8 where it happens twice in every key word, in the same positions).
354  * For now, I'm implementing this case "dumbly", w/o any unrolling.
355  */
356 static void
rijndael_key_expansion7(AESContext * cx,const unsigned char * key,unsigned int Nk)357 rijndael_key_expansion7(AESContext *cx, const unsigned char *key, unsigned int Nk)
358 {
359     unsigned int i;
360     PRUint32 *W;
361     PRUint32 *pW;
362     PRUint32 tmp;
363     W = cx->k.expandedKey;
364     /* 1.  the first Nk words contain the cipher key */
365     memcpy(W, key, Nk * 4);
366     i = Nk;
367     /* 2.  loop until full expanded key is obtained */
368     pW = W + i - 1;
369     for (; i < cx->Nb * (cx->Nr + 1); ++i) {
370         tmp = *pW++;
371         if (i % Nk == 0)
372             tmp = SUBBYTE(ROTBYTE(tmp)) ^ Rcon[i / Nk - 1];
373         else if (i % Nk == 4)
374             tmp = SUBBYTE(tmp);
375         *pW = W[i - Nk] ^ tmp;
376     }
377 }
378 
379 /* rijndael_key_expansion
380  *
381  * Generate the expanded key from the key input by the user.
382  */
383 static void
rijndael_key_expansion(AESContext * cx,const unsigned char * key,unsigned int Nk)384 rijndael_key_expansion(AESContext *cx, const unsigned char *key, unsigned int Nk)
385 {
386     unsigned int i;
387     PRUint32 *W;
388     PRUint32 *pW;
389     PRUint32 tmp;
390     unsigned int round_key_words = cx->Nb * (cx->Nr + 1);
391     if (Nk == 7) {
392         rijndael_key_expansion7(cx, key, Nk);
393         return;
394     }
395     W = cx->k.expandedKey;
396     /* The first Nk words contain the input cipher key */
397     memcpy(W, key, Nk * 4);
398     i = Nk;
399     pW = W + i - 1;
400     /* Loop over all sets of Nk words, except the last */
401     while (i < round_key_words - Nk) {
402         tmp = *pW++;
403         tmp = SUBBYTE(ROTBYTE(tmp)) ^ Rcon[i / Nk - 1];
404         *pW = W[i++ - Nk] ^ tmp;
405         tmp = *pW++;
406         *pW = W[i++ - Nk] ^ tmp;
407         tmp = *pW++;
408         *pW = W[i++ - Nk] ^ tmp;
409         tmp = *pW++;
410         *pW = W[i++ - Nk] ^ tmp;
411         if (Nk == 4)
412             continue;
413         switch (Nk) {
414             case 8:
415                 tmp = *pW++;
416                 tmp = SUBBYTE(tmp);
417                 *pW = W[i++ - Nk] ^ tmp;
418             case 7:
419                 tmp = *pW++;
420                 *pW = W[i++ - Nk] ^ tmp;
421             case 6:
422                 tmp = *pW++;
423                 *pW = W[i++ - Nk] ^ tmp;
424             case 5:
425                 tmp = *pW++;
426                 *pW = W[i++ - Nk] ^ tmp;
427         }
428     }
429     /* Generate the last word */
430     tmp = *pW++;
431     tmp = SUBBYTE(ROTBYTE(tmp)) ^ Rcon[i / Nk - 1];
432     *pW = W[i++ - Nk] ^ tmp;
433     /* There may be overflow here, if Nk % (Nb * (Nr + 1)) > 0.  However,
434      * since the above loop generated all but the last Nk key words, there
435      * is no more need for the SubByte transformation.
436      */
437     if (Nk < 8) {
438         for (; i < round_key_words; ++i) {
439             tmp = *pW++;
440             *pW = W[i - Nk] ^ tmp;
441         }
442     } else {
443         /* except in the case when Nk == 8.  Then one more SubByte may have
444          * to be performed, at i % Nk == 4.
445          */
446         for (; i < round_key_words; ++i) {
447             tmp = *pW++;
448             if (i % Nk == 4)
449                 tmp = SUBBYTE(tmp);
450             *pW = W[i - Nk] ^ tmp;
451         }
452     }
453 }
454 
455 /* rijndael_invkey_expansion
456  *
457  * Generate the expanded key for the inverse cipher from the key input by
458  * the user.
459  */
460 static void
rijndael_invkey_expansion(AESContext * cx,const unsigned char * key,unsigned int Nk)461 rijndael_invkey_expansion(AESContext *cx, const unsigned char *key, unsigned int Nk)
462 {
463     unsigned int r;
464     PRUint32 *roundkeyw;
465     PRUint8 *b;
466     int Nb = cx->Nb;
467     /* begins like usual key expansion ... */
468     rijndael_key_expansion(cx, key, Nk);
469     /* ... but has the additional step of InvMixColumn,
470      * excepting the first and last round keys.
471      */
472     roundkeyw = cx->k.expandedKey + cx->Nb;
473     for (r = 1; r < cx->Nr; ++r) {
474         /* each key word, roundkeyw, represents a column in the key
475          * matrix.  Each column is multiplied by the InvMixColumn matrix.
476          *   [ 0E 0B 0D 09 ]   [ b0 ]
477          *   [ 09 0E 0B 0D ] * [ b1 ]
478          *   [ 0D 09 0E 0B ]   [ b2 ]
479          *   [ 0B 0D 09 0E ]   [ b3 ]
480          */
481         b = (PRUint8 *)roundkeyw;
482         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
483         b = (PRUint8 *)roundkeyw;
484         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
485         b = (PRUint8 *)roundkeyw;
486         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
487         b = (PRUint8 *)roundkeyw;
488         *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^ IMXC2(b[2]) ^ IMXC3(b[3]);
489         if (Nb <= 4)
490             continue;
491         switch (Nb) {
492             case 8:
493                 b = (PRUint8 *)roundkeyw;
494                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
495                                IMXC2(b[2]) ^ IMXC3(b[3]);
496             case 7:
497                 b = (PRUint8 *)roundkeyw;
498                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
499                                IMXC2(b[2]) ^ IMXC3(b[3]);
500             case 6:
501                 b = (PRUint8 *)roundkeyw;
502                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
503                                IMXC2(b[2]) ^ IMXC3(b[3]);
504             case 5:
505                 b = (PRUint8 *)roundkeyw;
506                 *roundkeyw++ = IMXC0(b[0]) ^ IMXC1(b[1]) ^
507                                IMXC2(b[2]) ^ IMXC3(b[3]);
508         }
509     }
510 }
511 
512 /**************************************************************************
513  *
514  * Stuff related to Rijndael encryption/decryption.
515  *
516  *************************************************************************/
517 
/*
 * BYTEnWORD(w): keep only the byte of w that sits at memory offset n of the
 * word, in place (no shifting); all other bits are cleared.
 */
#ifdef IS_LITTLE_ENDIAN
#define BYTE0WORD(w) ((w) & 0x000000ff)
#define BYTE1WORD(w) ((w) & 0x0000ff00)
#define BYTE2WORD(w) ((w) & 0x00ff0000)
#define BYTE3WORD(w) ((w) & 0xff000000)
#else
#define BYTE0WORD(w) ((w) & 0xff000000)
#define BYTE1WORD(w) ((w) & 0x00ff0000)
#define BYTE2WORD(w) ((w) & 0x0000ff00)
#define BYTE3WORD(w) ((w) & 0x000000ff)
#endif
529 
530 typedef union {
531     PRUint32 w[4];
532     PRUint8 b[16];
533 } rijndael_state;
534 
535 #define COLUMN_0(state) state.w[0]
536 #define COLUMN_1(state) state.w[1]
537 #define COLUMN_2(state) state.w[2]
538 #define COLUMN_3(state) state.w[3]
539 
540 #define STATE_BYTE(i) state.b[i]
541 
542 // out = a ^ b
543 inline static void
xorBlock(unsigned char * out,const unsigned char * a,const unsigned char * b)544 xorBlock(unsigned char *out, const unsigned char *a, const unsigned char *b)
545 {
546     for (unsigned int j = 0; j < AES_BLOCK_SIZE; ++j) {
547         (out)[j] = (a)[j] ^ (b)[j];
548     }
549 }
550 
551 static void NO_SANITIZE_ALIGNMENT
rijndael_encryptBlock128(AESContext * cx,unsigned char * output,const unsigned char * input)552 rijndael_encryptBlock128(AESContext *cx,
553                          unsigned char *output,
554                          const unsigned char *input)
555 {
556     unsigned int r;
557     PRUint32 *roundkeyw;
558     rijndael_state state;
559     PRUint32 C0, C1, C2, C3;
560 #if defined(NSS_X86_OR_X64)
561 #define pIn input
562 #define pOut output
563 #else
564     unsigned char *pIn, *pOut;
565     PRUint32 inBuf[4], outBuf[4];
566 
567     if ((ptrdiff_t)input & 0x3) {
568         memcpy(inBuf, input, sizeof inBuf);
569         pIn = (unsigned char *)inBuf;
570     } else {
571         pIn = (unsigned char *)input;
572     }
573     if ((ptrdiff_t)output & 0x3) {
574         pOut = (unsigned char *)outBuf;
575     } else {
576         pOut = (unsigned char *)output;
577     }
578 #endif
579     roundkeyw = cx->k.expandedKey;
580     /* Step 1: Add Round Key 0 to initial state */
581     COLUMN_0(state) = *((PRUint32 *)(pIn)) ^ *roundkeyw++;
582     COLUMN_1(state) = *((PRUint32 *)(pIn + 4)) ^ *roundkeyw++;
583     COLUMN_2(state) = *((PRUint32 *)(pIn + 8)) ^ *roundkeyw++;
584     COLUMN_3(state) = *((PRUint32 *)(pIn + 12)) ^ *roundkeyw++;
585     /* Step 2: Loop over rounds [1..NR-1] */
586     for (r = 1; r < cx->Nr; ++r) {
587         /* Do ShiftRow, ByteSub, and MixColumn all at once */
588         C0 = T0(STATE_BYTE(0)) ^
589              T1(STATE_BYTE(5)) ^
590              T2(STATE_BYTE(10)) ^
591              T3(STATE_BYTE(15));
592         C1 = T0(STATE_BYTE(4)) ^
593              T1(STATE_BYTE(9)) ^
594              T2(STATE_BYTE(14)) ^
595              T3(STATE_BYTE(3));
596         C2 = T0(STATE_BYTE(8)) ^
597              T1(STATE_BYTE(13)) ^
598              T2(STATE_BYTE(2)) ^
599              T3(STATE_BYTE(7));
600         C3 = T0(STATE_BYTE(12)) ^
601              T1(STATE_BYTE(1)) ^
602              T2(STATE_BYTE(6)) ^
603              T3(STATE_BYTE(11));
604         /* Round key addition */
605         COLUMN_0(state) = C0 ^ *roundkeyw++;
606         COLUMN_1(state) = C1 ^ *roundkeyw++;
607         COLUMN_2(state) = C2 ^ *roundkeyw++;
608         COLUMN_3(state) = C3 ^ *roundkeyw++;
609     }
610     /* Step 3: Do the last round */
611     /* Final round does not employ MixColumn */
612     C0 = ((BYTE0WORD(T2(STATE_BYTE(0)))) |
613           (BYTE1WORD(T3(STATE_BYTE(5)))) |
614           (BYTE2WORD(T0(STATE_BYTE(10)))) |
615           (BYTE3WORD(T1(STATE_BYTE(15))))) ^
616          *roundkeyw++;
617     C1 = ((BYTE0WORD(T2(STATE_BYTE(4)))) |
618           (BYTE1WORD(T3(STATE_BYTE(9)))) |
619           (BYTE2WORD(T0(STATE_BYTE(14)))) |
620           (BYTE3WORD(T1(STATE_BYTE(3))))) ^
621          *roundkeyw++;
622     C2 = ((BYTE0WORD(T2(STATE_BYTE(8)))) |
623           (BYTE1WORD(T3(STATE_BYTE(13)))) |
624           (BYTE2WORD(T0(STATE_BYTE(2)))) |
625           (BYTE3WORD(T1(STATE_BYTE(7))))) ^
626          *roundkeyw++;
627     C3 = ((BYTE0WORD(T2(STATE_BYTE(12)))) |
628           (BYTE1WORD(T3(STATE_BYTE(1)))) |
629           (BYTE2WORD(T0(STATE_BYTE(6)))) |
630           (BYTE3WORD(T1(STATE_BYTE(11))))) ^
631          *roundkeyw++;
632     *((PRUint32 *)pOut) = C0;
633     *((PRUint32 *)(pOut + 4)) = C1;
634     *((PRUint32 *)(pOut + 8)) = C2;
635     *((PRUint32 *)(pOut + 12)) = C3;
636 #if defined(NSS_X86_OR_X64)
637 #undef pIn
638 #undef pOut
639 #else
640     if ((ptrdiff_t)output & 0x3) {
641         memcpy(output, outBuf, sizeof outBuf);
642     }
643 #endif
644 }
645 
646 static void NO_SANITIZE_ALIGNMENT
rijndael_decryptBlock128(AESContext * cx,unsigned char * output,const unsigned char * input)647 rijndael_decryptBlock128(AESContext *cx,
648                          unsigned char *output,
649                          const unsigned char *input)
650 {
651     int r;
652     PRUint32 *roundkeyw;
653     rijndael_state state;
654     PRUint32 C0, C1, C2, C3;
655 #if defined(NSS_X86_OR_X64)
656 #define pIn input
657 #define pOut output
658 #else
659     unsigned char *pIn, *pOut;
660     PRUint32 inBuf[4], outBuf[4];
661 
662     if ((ptrdiff_t)input & 0x3) {
663         memcpy(inBuf, input, sizeof inBuf);
664         pIn = (unsigned char *)inBuf;
665     } else {
666         pIn = (unsigned char *)input;
667     }
668     if ((ptrdiff_t)output & 0x3) {
669         pOut = (unsigned char *)outBuf;
670     } else {
671         pOut = (unsigned char *)output;
672     }
673 #endif
674     roundkeyw = cx->k.expandedKey + cx->Nb * cx->Nr + 3;
675     /* reverse the final key addition */
676     COLUMN_3(state) = *((PRUint32 *)(pIn + 12)) ^ *roundkeyw--;
677     COLUMN_2(state) = *((PRUint32 *)(pIn + 8)) ^ *roundkeyw--;
678     COLUMN_1(state) = *((PRUint32 *)(pIn + 4)) ^ *roundkeyw--;
679     COLUMN_0(state) = *((PRUint32 *)(pIn)) ^ *roundkeyw--;
680     /* Loop over rounds in reverse [NR..1] */
681     for (r = cx->Nr; r > 1; --r) {
682         /* Invert the (InvByteSub*InvMixColumn)(InvShiftRow(state)) */
683         C0 = TInv0(STATE_BYTE(0)) ^
684              TInv1(STATE_BYTE(13)) ^
685              TInv2(STATE_BYTE(10)) ^
686              TInv3(STATE_BYTE(7));
687         C1 = TInv0(STATE_BYTE(4)) ^
688              TInv1(STATE_BYTE(1)) ^
689              TInv2(STATE_BYTE(14)) ^
690              TInv3(STATE_BYTE(11));
691         C2 = TInv0(STATE_BYTE(8)) ^
692              TInv1(STATE_BYTE(5)) ^
693              TInv2(STATE_BYTE(2)) ^
694              TInv3(STATE_BYTE(15));
695         C3 = TInv0(STATE_BYTE(12)) ^
696              TInv1(STATE_BYTE(9)) ^
697              TInv2(STATE_BYTE(6)) ^
698              TInv3(STATE_BYTE(3));
699         /* Invert the key addition step */
700         COLUMN_3(state) = C3 ^ *roundkeyw--;
701         COLUMN_2(state) = C2 ^ *roundkeyw--;
702         COLUMN_1(state) = C1 ^ *roundkeyw--;
703         COLUMN_0(state) = C0 ^ *roundkeyw--;
704     }
705     /* inverse sub */
706     pOut[0] = SINV(STATE_BYTE(0));
707     pOut[1] = SINV(STATE_BYTE(13));
708     pOut[2] = SINV(STATE_BYTE(10));
709     pOut[3] = SINV(STATE_BYTE(7));
710     pOut[4] = SINV(STATE_BYTE(4));
711     pOut[5] = SINV(STATE_BYTE(1));
712     pOut[6] = SINV(STATE_BYTE(14));
713     pOut[7] = SINV(STATE_BYTE(11));
714     pOut[8] = SINV(STATE_BYTE(8));
715     pOut[9] = SINV(STATE_BYTE(5));
716     pOut[10] = SINV(STATE_BYTE(2));
717     pOut[11] = SINV(STATE_BYTE(15));
718     pOut[12] = SINV(STATE_BYTE(12));
719     pOut[13] = SINV(STATE_BYTE(9));
720     pOut[14] = SINV(STATE_BYTE(6));
721     pOut[15] = SINV(STATE_BYTE(3));
722     /* final key addition */
723     *((PRUint32 *)(pOut + 12)) ^= *roundkeyw--;
724     *((PRUint32 *)(pOut + 8)) ^= *roundkeyw--;
725     *((PRUint32 *)(pOut + 4)) ^= *roundkeyw--;
726     *((PRUint32 *)pOut) ^= *roundkeyw--;
727 #if defined(NSS_X86_OR_X64)
728 #undef pIn
729 #undef pOut
730 #else
731     if ((ptrdiff_t)output & 0x3) {
732         memcpy(output, outBuf, sizeof outBuf);
733     }
734 #endif
735 }
736 
737 /**************************************************************************
738  *
739  *  Rijndael modes of operation (ECB and CBC)
740  *
741  *************************************************************************/
742 
743 static SECStatus
rijndael_encryptECB(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)744 rijndael_encryptECB(AESContext *cx, unsigned char *output,
745                     unsigned int *outputLen, unsigned int maxOutputLen,
746                     const unsigned char *input, unsigned int inputLen)
747 {
748     PRBool aesni = aesni_support();
749     while (inputLen > 0) {
750         if (aesni) {
751             rijndael_native_encryptBlock(cx, output, input);
752         } else {
753             rijndael_encryptBlock128(cx, output, input);
754         }
755         output += AES_BLOCK_SIZE;
756         input += AES_BLOCK_SIZE;
757         inputLen -= AES_BLOCK_SIZE;
758     }
759     return SECSuccess;
760 }
761 
762 static SECStatus
rijndael_encryptCBC(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)763 rijndael_encryptCBC(AESContext *cx, unsigned char *output,
764                     unsigned int *outputLen, unsigned int maxOutputLen,
765                     const unsigned char *input, unsigned int inputLen)
766 {
767     unsigned char *lastblock = cx->iv;
768     unsigned char inblock[AES_BLOCK_SIZE * 8];
769     PRBool aesni = aesni_support();
770 
771     if (!inputLen)
772         return SECSuccess;
773     while (inputLen > 0) {
774         if (aesni) {
775             /* XOR with the last block (IV if first block) */
776             native_xorBlock(inblock, input, lastblock);
777             /* encrypt */
778             rijndael_native_encryptBlock(cx, output, inblock);
779         } else {
780             xorBlock(inblock, input, lastblock);
781             rijndael_encryptBlock128(cx, output, inblock);
782         }
783 
784         /* move to the next block */
785         lastblock = output;
786         output += AES_BLOCK_SIZE;
787         input += AES_BLOCK_SIZE;
788         inputLen -= AES_BLOCK_SIZE;
789     }
790     memcpy(cx->iv, lastblock, AES_BLOCK_SIZE);
791     return SECSuccess;
792 }
793 
794 static SECStatus
rijndael_decryptECB(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)795 rijndael_decryptECB(AESContext *cx, unsigned char *output,
796                     unsigned int *outputLen, unsigned int maxOutputLen,
797                     const unsigned char *input, unsigned int inputLen)
798 {
799     PRBool aesni = aesni_support();
800     while (inputLen > 0) {
801         if (aesni) {
802             rijndael_native_decryptBlock(cx, output, input);
803         } else {
804             rijndael_decryptBlock128(cx, output, input);
805         }
806         output += AES_BLOCK_SIZE;
807         input += AES_BLOCK_SIZE;
808         inputLen -= AES_BLOCK_SIZE;
809     }
810     return SECSuccess;
811 }
812 
813 static SECStatus
rijndael_decryptCBC(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)814 rijndael_decryptCBC(AESContext *cx, unsigned char *output,
815                     unsigned int *outputLen, unsigned int maxOutputLen,
816                     const unsigned char *input, unsigned int inputLen)
817 {
818     const unsigned char *in;
819     unsigned char *out;
820     unsigned char newIV[AES_BLOCK_SIZE];
821     PRBool aesni = aesni_support();
822 
823     if (!inputLen)
824         return SECSuccess;
825     PORT_Assert(output - input >= 0 || input - output >= (int)inputLen);
826     in = input + (inputLen - AES_BLOCK_SIZE);
827     memcpy(newIV, in, AES_BLOCK_SIZE);
828     out = output + (inputLen - AES_BLOCK_SIZE);
829     while (inputLen > AES_BLOCK_SIZE) {
830         if (aesni) {
831             // Use hardware acceleration for normal AES parameters.
832             rijndael_native_decryptBlock(cx, out, in);
833             native_xorBlock(out, out, &in[-AES_BLOCK_SIZE]);
834         } else {
835             rijndael_decryptBlock128(cx, out, in);
836             xorBlock(out, out, &in[-AES_BLOCK_SIZE]);
837         }
838         out -= AES_BLOCK_SIZE;
839         in -= AES_BLOCK_SIZE;
840         inputLen -= AES_BLOCK_SIZE;
841     }
842     if (in == input) {
843         if (aesni) {
844             rijndael_native_decryptBlock(cx, out, in);
845             native_xorBlock(out, out, cx->iv);
846         } else {
847             rijndael_decryptBlock128(cx, out, in);
848             xorBlock(out, out, cx->iv);
849         }
850     }
851     memcpy(cx->iv, newIV, AES_BLOCK_SIZE);
852     return SECSuccess;
853 }
854 
855 /************************************************************************
856  *
857  * BLAPI Interface functions
858  *
859  * The following functions implement the encryption routines defined in
860  * BLAPI for the AES cipher, Rijndael.
861  *
862  ***********************************************************************/
863 
864 AESContext *
AES_AllocateContext(void)865 AES_AllocateContext(void)
866 {
867     return PORT_ZNewAligned(AESContext, 16, mem);
868 }
869 
870 /*
871 ** Initialize a new AES context suitable for AES encryption/decryption in
872 ** the ECB or CBC mode.
873 **  "mode" the mode of operation, which must be NSS_AES or NSS_AES_CBC
874 */
875 static SECStatus
aes_InitContext(AESContext * cx,const unsigned char * key,unsigned int keysize,const unsigned char * iv,int mode,unsigned int encrypt)876 aes_InitContext(AESContext *cx, const unsigned char *key, unsigned int keysize,
877                 const unsigned char *iv, int mode, unsigned int encrypt)
878 {
879     unsigned int Nk;
880     PRBool use_hw_aes;
881     /* According to AES, block lengths are 128 and key lengths are 128, 192, or
882      * 256 bits. We support other key sizes as well [128, 256] as long as the
883      * length in bytes is divisible by 4.
884      */
885 
886     if (key == NULL ||
887         keysize < AES_BLOCK_SIZE ||
888         keysize > 32 ||
889         keysize % 4 != 0) {
890         PORT_SetError(SEC_ERROR_INVALID_ARGS);
891         return SECFailure;
892     }
893     if (mode != NSS_AES && mode != NSS_AES_CBC) {
894         PORT_SetError(SEC_ERROR_INVALID_ARGS);
895         return SECFailure;
896     }
897     if (mode == NSS_AES_CBC && iv == NULL) {
898         PORT_SetError(SEC_ERROR_INVALID_ARGS);
899         return SECFailure;
900     }
901     if (!cx) {
902         PORT_SetError(SEC_ERROR_INVALID_ARGS);
903         return SECFailure;
904     }
905 #if defined(NSS_X86_OR_X64) || defined(USE_HW_AES)
906     use_hw_aes = (aesni_support() || arm_aes_support()) && (keysize % 8) == 0;
907 #else
908     use_hw_aes = PR_FALSE;
909 #endif
910     /* Nb = (block size in bits) / 32 */
911     cx->Nb = AES_BLOCK_SIZE / 4;
912     /* Nk = (key size in bits) / 32 */
913     Nk = keysize / 4;
914     /* Obtain number of rounds from "table" */
915     cx->Nr = RIJNDAEL_NUM_ROUNDS(Nk, cx->Nb);
916     /* copy in the iv, if neccessary */
917     if (mode == NSS_AES_CBC) {
918         memcpy(cx->iv, iv, AES_BLOCK_SIZE);
919 #ifdef USE_HW_AES
920         if (use_hw_aes) {
921             cx->worker = (freeblCipherFunc)
922                 native_aes_cbc_worker(encrypt, keysize);
923         } else
924 #endif
925         {
926             cx->worker = (freeblCipherFunc)(encrypt
927                                                 ? &rijndael_encryptCBC
928                                                 : &rijndael_decryptCBC);
929         }
930     } else {
931 #ifdef USE_HW_AES
932         if (use_hw_aes) {
933             cx->worker = (freeblCipherFunc)
934                 native_aes_ecb_worker(encrypt, keysize);
935         } else
936 #endif
937         {
938             cx->worker = (freeblCipherFunc)(encrypt
939                                                 ? &rijndael_encryptECB
940                                                 : &rijndael_decryptECB);
941         }
942     }
943     PORT_Assert((cx->Nb * (cx->Nr + 1)) <= RIJNDAEL_MAX_EXP_KEY_SIZE);
944     if ((cx->Nb * (cx->Nr + 1)) > RIJNDAEL_MAX_EXP_KEY_SIZE) {
945         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
946         return SECFailure;
947     }
948 #ifdef USE_HW_AES
949     if (use_hw_aes) {
950         native_aes_init(encrypt, keysize);
951     } else
952 #endif
953     {
954         /* Generate expanded key */
955         if (encrypt) {
956             if (use_hw_aes && (cx->mode == NSS_AES_GCM || cx->mode == NSS_AES ||
957                                cx->mode == NSS_AES_CTR)) {
958                 PORT_Assert(keysize == 16 || keysize == 24 || keysize == 32);
959                 /* Prepare hardware key for normal AES parameters. */
960                 rijndael_native_key_expansion(cx, key, Nk);
961             } else {
962                 rijndael_key_expansion(cx, key, Nk);
963             }
964         } else {
965             rijndael_invkey_expansion(cx, key, Nk);
966         }
967         BLAPI_CLEAR_STACK(256)
968     }
969     cx->worker_cx = cx;
970     cx->destroy = NULL;
971     cx->isBlock = PR_TRUE;
972     return SECSuccess;
973 }
974 
975 SECStatus
AES_InitContext(AESContext * cx,const unsigned char * key,unsigned int keysize,const unsigned char * iv,int mode,unsigned int encrypt,unsigned int blocksize)976 AES_InitContext(AESContext *cx, const unsigned char *key, unsigned int keysize,
977                 const unsigned char *iv, int mode, unsigned int encrypt,
978                 unsigned int blocksize)
979 {
980     int basemode = mode;
981     PRBool baseencrypt = encrypt;
982     SECStatus rv;
983 
984     if (blocksize != AES_BLOCK_SIZE) {
985         PORT_SetError(SEC_ERROR_INVALID_ARGS);
986         return SECFailure;
987     }
988 
989     switch (mode) {
990         case NSS_AES_CTS:
991             basemode = NSS_AES_CBC;
992             break;
993         case NSS_AES_GCM:
994         case NSS_AES_CTR:
995             basemode = NSS_AES;
996             baseencrypt = PR_TRUE;
997             break;
998     }
999     /* Make sure enough is initialized so we can safely call Destroy. */
1000     cx->worker_cx = NULL;
1001     cx->destroy = NULL;
1002     cx->mode = mode;
1003     rv = aes_InitContext(cx, key, keysize, iv, basemode, baseencrypt);
1004     if (rv != SECSuccess) {
1005         AES_DestroyContext(cx, PR_FALSE);
1006         return rv;
1007     }
1008 
1009     /* finally, set up any mode specific contexts */
1010     cx->worker_aead = 0;
1011     switch (mode) {
1012         case NSS_AES_CTS:
1013             cx->worker_cx = CTS_CreateContext(cx, cx->worker, iv);
1014             cx->worker = (freeblCipherFunc)(encrypt ? CTS_EncryptUpdate : CTS_DecryptUpdate);
1015             cx->destroy = (freeblDestroyFunc)CTS_DestroyContext;
1016             cx->isBlock = PR_FALSE;
1017             break;
1018         case NSS_AES_GCM:
1019 #if defined(INTEL_GCM) && defined(USE_HW_AES)
1020             if (aesni_support() && (keysize % 8) == 0 && avx_support() &&
1021                 clmul_support()) {
1022                 cx->worker_cx = intel_AES_GCM_CreateContext(cx, cx->worker, iv);
1023                 cx->worker = (freeblCipherFunc)(encrypt ? intel_AES_GCM_EncryptUpdate
1024                                                         : intel_AES_GCM_DecryptUpdate);
1025                 cx->worker_aead = (freeblAeadFunc)(encrypt ? intel_AES_GCM_EncryptAEAD
1026                                                            : intel_AES_GCM_DecryptAEAD);
1027                 cx->destroy = (freeblDestroyFunc)intel_AES_GCM_DestroyContext;
1028                 cx->isBlock = PR_FALSE;
1029             } else
1030 #elif defined(USE_PPC_CRYPTO) && defined(PPC_GCM)
1031             if (ppc_crypto_support() && (keysize % 8) == 0) {
1032                 cx->worker_cx = ppc_AES_GCM_CreateContext(cx, cx->worker, iv);
1033                 cx->worker = (freeblCipherFunc)(encrypt ? ppc_AES_GCM_EncryptUpdate
1034                                                         : ppc_AES_GCM_DecryptUpdate);
1035                 cx->worker_aead = (freeblAeadFunc)(encrypt ? ppc_AES_GCM_EncryptAEAD
1036                                                            : ppc_AES_GCM_DecryptAEAD);
1037                 cx->destroy = (freeblDestroyFunc)ppc_AES_GCM_DestroyContext;
1038                 cx->isBlock = PR_FALSE;
1039             } else
1040 #endif
1041             {
1042                 cx->worker_cx = GCM_CreateContext(cx, cx->worker, iv);
1043                 cx->worker = (freeblCipherFunc)(encrypt ? GCM_EncryptUpdate
1044                                                         : GCM_DecryptUpdate);
1045                 cx->worker_aead = (freeblAeadFunc)(encrypt ? GCM_EncryptAEAD
1046                                                            : GCM_DecryptAEAD);
1047 
1048                 cx->destroy = (freeblDestroyFunc)GCM_DestroyContext;
1049                 cx->isBlock = PR_FALSE;
1050             }
1051             break;
1052         case NSS_AES_CTR:
1053             cx->worker_cx = CTR_CreateContext(cx, cx->worker, iv);
1054 #if defined(USE_HW_AES) && defined(_MSC_VER) && defined(NSS_X86_OR_X64)
1055             if (aesni_support() && (keysize % 8) == 0) {
1056                 cx->worker = (freeblCipherFunc)CTR_Update_HW_AES;
1057             } else
1058 #endif
1059             {
1060                 cx->worker = (freeblCipherFunc)CTR_Update;
1061             }
1062             cx->destroy = (freeblDestroyFunc)CTR_DestroyContext;
1063             cx->isBlock = PR_FALSE;
1064             break;
1065         default:
1066             /* everything has already been set up by aes_InitContext, just
1067              * return */
1068             return SECSuccess;
1069     }
1070     /* check to see if we succeeded in getting the worker context */
1071     if (cx->worker_cx == NULL) {
1072         /* no, just destroy the existing context */
1073         cx->destroy = NULL; /* paranoia, though you can see a dozen lines */
1074                             /* below that this isn't necessary */
1075         AES_DestroyContext(cx, PR_FALSE);
1076         return SECFailure;
1077     }
1078     return SECSuccess;
1079 }
1080 
1081 /* AES_CreateContext
1082  *
1083  * create a new context for Rijndael operations
1084  */
1085 AESContext *
AES_CreateContext(const unsigned char * key,const unsigned char * iv,int mode,int encrypt,unsigned int keysize,unsigned int blocksize)1086 AES_CreateContext(const unsigned char *key, const unsigned char *iv,
1087                   int mode, int encrypt,
1088                   unsigned int keysize, unsigned int blocksize)
1089 {
1090     AESContext *cx = AES_AllocateContext();
1091     if (cx) {
1092         SECStatus rv = AES_InitContext(cx, key, keysize, iv, mode, encrypt,
1093                                        blocksize);
1094         if (rv != SECSuccess) {
1095             AES_DestroyContext(cx, PR_TRUE);
1096             cx = NULL;
1097         }
1098     }
1099     return cx;
1100 }
1101 
1102 /*
1103  * AES_DestroyContext
1104  *
1105  * Zero an AES cipher context.  If freeit is true, also free the pointer
1106  * to the context.
1107  */
1108 void
AES_DestroyContext(AESContext * cx,PRBool freeit)1109 AES_DestroyContext(AESContext *cx, PRBool freeit)
1110 {
1111     void *mem = cx->mem;
1112     if (cx->worker_cx && cx->destroy) {
1113         (*cx->destroy)(cx->worker_cx, PR_TRUE);
1114         cx->worker_cx = NULL;
1115         cx->destroy = NULL;
1116     }
1117     PORT_Memset(cx, 0, sizeof(AESContext));
1118     if (freeit) {
1119         PORT_Free(mem);
1120     } else {
1121         /* if we are not freeing the context, restore mem, We may get called
1122          * again to actually free the context */
1123         cx->mem = mem;
1124     }
1125 }
1126 
1127 /*
1128  * AES_Encrypt
1129  *
1130  * Encrypt an arbitrary-length buffer.  The output buffer must already be
1131  * allocated to at least inputLen.
1132  */
1133 SECStatus
AES_Encrypt(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)1134 AES_Encrypt(AESContext *cx, unsigned char *output,
1135             unsigned int *outputLen, unsigned int maxOutputLen,
1136             const unsigned char *input, unsigned int inputLen)
1137 {
1138     /* Check args */
1139     SECStatus rv;
1140     if (cx == NULL || output == NULL || (input == NULL && inputLen != 0)) {
1141         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1142         return SECFailure;
1143     }
1144     if (cx->isBlock && (inputLen % AES_BLOCK_SIZE != 0)) {
1145         PORT_SetError(SEC_ERROR_INPUT_LEN);
1146         return SECFailure;
1147     }
1148     if (maxOutputLen < inputLen) {
1149         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1150         return SECFailure;
1151     }
1152     *outputLen = inputLen;
1153 #if UINT_MAX > MP_32BIT_MAX
1154     /*
1155      * we can guarentee that GSM won't overlfow if we limit the input to
1156      * 2^36 bytes. For simplicity, we are limiting it to 2^32 for now.
1157      *
1158      * We do it here to cover both hardware and software GCM operations.
1159      */
1160     {
1161         PR_STATIC_ASSERT(sizeof(unsigned int) > 4);
1162     }
1163     if ((cx->mode == NSS_AES_GCM) && (inputLen > MP_32BIT_MAX)) {
1164         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1165         return SECFailure;
1166     }
1167 #else
1168     /* if we can't pass in a 32_bit number, then no such check needed */
1169     {
1170         PR_STATIC_ASSERT(sizeof(unsigned int) <= 4);
1171     }
1172 #endif
1173 
1174     rv = (*cx->worker)(cx->worker_cx, output, outputLen, maxOutputLen,
1175                        input, inputLen, AES_BLOCK_SIZE);
1176     BLAPI_CLEAR_STACK(256)
1177     return rv;
1178 }
1179 
1180 /*
1181  * AES_Decrypt
1182  *
1183  * Decrypt and arbitrary-length buffer.  The output buffer must already be
1184  * allocated to at least inputLen.
1185  */
1186 SECStatus
AES_Decrypt(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)1187 AES_Decrypt(AESContext *cx, unsigned char *output,
1188             unsigned int *outputLen, unsigned int maxOutputLen,
1189             const unsigned char *input, unsigned int inputLen)
1190 {
1191     SECStatus rv;
1192     /* Check args */
1193     if (cx == NULL || output == NULL || (input == NULL && inputLen != 0)) {
1194         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1195         return SECFailure;
1196     }
1197     if (cx->isBlock && (inputLen % AES_BLOCK_SIZE != 0)) {
1198         PORT_SetError(SEC_ERROR_INPUT_LEN);
1199         return SECFailure;
1200     }
1201     if ((cx->mode != NSS_AES_GCM) && (maxOutputLen < inputLen)) {
1202         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1203         return SECFailure;
1204     }
1205     *outputLen = inputLen;
1206     rv = (*cx->worker)(cx->worker_cx, output, outputLen, maxOutputLen,
1207                        input, inputLen, AES_BLOCK_SIZE);
1208     BLAPI_CLEAR_STACK(256)
1209     return rv;
1210 }
1211 
1212 /*
1213  * AES_Encrypt_AEAD
1214  *
1215  * Encrypt using GCM or CCM. include the nonce, extra data, and the tag
1216  */
1217 SECStatus
AES_AEAD(AESContext * cx,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen,void * params,unsigned int paramsLen,const unsigned char * aad,unsigned int aadLen)1218 AES_AEAD(AESContext *cx, unsigned char *output,
1219          unsigned int *outputLen, unsigned int maxOutputLen,
1220          const unsigned char *input, unsigned int inputLen,
1221          void *params, unsigned int paramsLen,
1222          const unsigned char *aad, unsigned int aadLen)
1223 {
1224     SECStatus rv;
1225     /* Check args */
1226     if (cx == NULL || output == NULL || (input == NULL && inputLen != 0) || (aad == NULL && aadLen != 0) || params == NULL) {
1227         PORT_SetError(SEC_ERROR_INVALID_ARGS);
1228         return SECFailure;
1229     }
1230     if (cx->worker_aead == NULL) {
1231         PORT_SetError(SEC_ERROR_NOT_INITIALIZED);
1232         return SECFailure;
1233     }
1234     if (maxOutputLen < inputLen) {
1235         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1236         return SECFailure;
1237     }
1238     *outputLen = inputLen;
1239 #if UINT_MAX > MP_32BIT_MAX
1240     /*
1241      * we can guarentee that GSM won't overlfow if we limit the input to
1242      * 2^36 bytes. For simplicity, we are limiting it to 2^32 for now.
1243      *
1244      * We do it here to cover both hardware and software GCM operations.
1245      */
1246     {
1247         PR_STATIC_ASSERT(sizeof(unsigned int) > 4);
1248     }
1249     if (inputLen > MP_32BIT_MAX) {
1250         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1251         return SECFailure;
1252     }
1253 #else
1254     /* if we can't pass in a 32_bit number, then no such check needed */
1255     {
1256         PR_STATIC_ASSERT(sizeof(unsigned int) <= 4);
1257     }
1258 #endif
1259 
1260     rv = (*cx->worker_aead)(cx->worker_cx, output, outputLen, maxOutputLen,
1261                             input, inputLen, params, paramsLen, aad, aadLen,
1262                             AES_BLOCK_SIZE);
1263     BLAPI_CLEAR_STACK(256)
1264     return rv;
1265 }
1266