1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
5 #ifdef FREEBL_NO_DEPEND
6 #include "stubs.h"
7 #endif
8 #include "prtypes.h"
9 #include "blapit.h"
10 #include "blapii.h"
11 #include "ctr.h"
12 #include "pkcs11t.h"
13 #include "secerr.h"
14 
15 #ifdef USE_HW_AES
16 #ifdef NSS_X86_OR_X64
17 #include "intel-aes.h"
18 #endif
19 #include "rijndael.h"
20 #endif
21 
22 #if defined(__ARM_NEON) || defined(__ARM_NEON__)
23 #include <arm_neon.h>
24 #endif
25 
26 SECStatus
CTR_InitContext(CTRContext * ctr,void * context,freeblCipherFunc cipher,const unsigned char * param)27 CTR_InitContext(CTRContext *ctr, void *context, freeblCipherFunc cipher,
28                 const unsigned char *param)
29 {
30     const CK_AES_CTR_PARAMS *ctrParams = (const CK_AES_CTR_PARAMS *)param;
31 
32     if (ctrParams->ulCounterBits == 0 ||
33         ctrParams->ulCounterBits > AES_BLOCK_SIZE * PR_BITS_PER_BYTE) {
34         PORT_SetError(SEC_ERROR_INVALID_ARGS);
35         return SECFailure;
36     }
37 
38     /* Invariant: 0 < ctr->bufPtr <= AES_BLOCK_SIZE */
39     ctr->checkWrap = PR_FALSE;
40     ctr->bufPtr = AES_BLOCK_SIZE; /* no unused data in the buffer */
41     ctr->cipher = cipher;
42     ctr->context = context;
43     ctr->counterBits = ctrParams->ulCounterBits;
44     if (AES_BLOCK_SIZE > sizeof(ctr->counter) ||
45         AES_BLOCK_SIZE > sizeof(ctrParams->cb)) {
46         PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
47         return SECFailure;
48     }
49     PORT_Memcpy(ctr->counter, ctrParams->cb, AES_BLOCK_SIZE);
50     if (ctr->counterBits < 64) {
51         PORT_Memcpy(ctr->counterFirst, ctr->counter, AES_BLOCK_SIZE);
52         ctr->checkWrap = PR_TRUE;
53     }
54     return SECSuccess;
55 }
56 
57 CTRContext *
CTR_CreateContext(void * context,freeblCipherFunc cipher,const unsigned char * param)58 CTR_CreateContext(void *context, freeblCipherFunc cipher,
59                   const unsigned char *param)
60 {
61     CTRContext *ctr;
62     SECStatus rv;
63 
64     /* first fill in the Counter context */
65     ctr = PORT_ZNew(CTRContext);
66     if (ctr == NULL) {
67         return NULL;
68     }
69     rv = CTR_InitContext(ctr, context, cipher, param);
70     if (rv != SECSuccess) {
71         CTR_DestroyContext(ctr, PR_TRUE);
72         ctr = NULL;
73     }
74     return ctr;
75 }
76 
77 void
CTR_DestroyContext(CTRContext * ctr,PRBool freeit)78 CTR_DestroyContext(CTRContext *ctr, PRBool freeit)
79 {
80     PORT_Memset(ctr, 0, sizeof(CTRContext));
81     if (freeit) {
82         PORT_Free(ctr);
83     }
84 }
85 
86 /*
87  * Used by counter mode. Increment the counter block. Not all bits in the
88  * counter block are part of the counter, counterBits tells how many bits
89  * are part of the counter. The counter block is blocksize long. It's a
90  * big endian value.
91  *
92  * XXX Does not handle counter rollover.
93  */
94 static void
ctr_GetNextCtr(unsigned char * counter,unsigned int counterBits,unsigned int blocksize)95 ctr_GetNextCtr(unsigned char *counter, unsigned int counterBits,
96                unsigned int blocksize)
97 {
98     unsigned char *counterPtr = counter + blocksize - 1;
99     unsigned char mask, count;
100 
101     PORT_Assert(counterBits <= blocksize * PR_BITS_PER_BYTE);
102     while (counterBits >= PR_BITS_PER_BYTE) {
103         if (++(*(counterPtr--))) {
104             return;
105         }
106         counterBits -= PR_BITS_PER_BYTE;
107     }
108     if (counterBits == 0) {
109         return;
110     }
111     /* increment the final partial byte */
112     mask = (1 << counterBits) - 1;
113     count = ++(*counterPtr) & mask;
114     *counterPtr = ((*counterPtr) & ~mask) | count;
115     return;
116 }
117 
/*
 * Store x[i] ^ y[i] into target[i] for every i in [0, count).
 *
 * Each chunk (or byte) of input is read before the matching output is
 * written, so target may be exactly the same buffer as x or y. When NEON is
 * available, whole 16-byte chunks are handled with vector loads/stores and
 * the remaining tail byte by byte.
 */
static void
ctr_xor(unsigned char *target, const unsigned char *x,
        const unsigned char *y, unsigned int count)
{
    unsigned int done = 0;

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
    for (; count - done >= 16; done += 16) {
        uint8x16_t lhs = vld1q_u8(x + done);
        uint8x16_t rhs = vld1q_u8(y + done);
        vst1q_u8(target + done, veorq_u8(lhs, rhs));
    }
#endif
    for (; done < count; ++done) {
        target[done] = (unsigned char)(x[done] ^ y[done]);
    }
}
136 
137 SECStatus
CTR_Update(CTRContext * ctr,unsigned char * outbuf,unsigned int * outlen,unsigned int maxout,const unsigned char * inbuf,unsigned int inlen,unsigned int blocksize)138 CTR_Update(CTRContext *ctr, unsigned char *outbuf,
139            unsigned int *outlen, unsigned int maxout,
140            const unsigned char *inbuf, unsigned int inlen,
141            unsigned int blocksize)
142 {
143     unsigned int tmp;
144     SECStatus rv;
145 
146     // Limit block count to 2^counterBits - 2
147     if (ctr->counterBits < (sizeof(unsigned int) * 8) &&
148         inlen > ((1 << ctr->counterBits) - 2) * AES_BLOCK_SIZE) {
149         PORT_SetError(SEC_ERROR_INPUT_LEN);
150         return SECFailure;
151     }
152     if (maxout < inlen) {
153         *outlen = inlen;
154         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
155         return SECFailure;
156     }
157     *outlen = 0;
158     if (ctr->bufPtr != blocksize) {
159         unsigned int needed = PR_MIN(blocksize - ctr->bufPtr, inlen);
160         ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed);
161         ctr->bufPtr += needed;
162         outbuf += needed;
163         inbuf += needed;
164         *outlen += needed;
165         inlen -= needed;
166         if (inlen == 0) {
167             return SECSuccess;
168         }
169         PORT_Assert(ctr->bufPtr == blocksize);
170     }
171 
172     while (inlen >= blocksize) {
173         rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize,
174                             ctr->counter, blocksize, blocksize);
175         ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize);
176         if (ctr->checkWrap) {
177             if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) {
178                 PORT_SetError(SEC_ERROR_INVALID_ARGS);
179                 return SECFailure;
180             }
181         }
182         if (rv != SECSuccess) {
183             return SECFailure;
184         }
185         ctr_xor(outbuf, inbuf, ctr->buffer, blocksize);
186         outbuf += blocksize;
187         inbuf += blocksize;
188         *outlen += blocksize;
189         inlen -= blocksize;
190     }
191     if (inlen == 0) {
192         return SECSuccess;
193     }
194     rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize,
195                         ctr->counter, blocksize, blocksize);
196     ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize);
197     if (ctr->checkWrap) {
198         if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) {
199             PORT_SetError(SEC_ERROR_INVALID_ARGS);
200             return SECFailure;
201         }
202     }
203     if (rv != SECSuccess) {
204         return SECFailure;
205     }
206     ctr_xor(outbuf, inbuf, ctr->buffer, inlen);
207     ctr->bufPtr = inlen;
208     *outlen += inlen;
209     return SECSuccess;
210 }
211 
#if defined(USE_HW_AES) && defined(_MSC_VER) && defined(NSS_X86_OR_X64)
/*
 * Counter-mode update using the Intel AES-NI worker for whole blocks.
 *
 * Same contract as CTR_Update: leftover keystream is consumed first, whole
 * blocks are handed to intel_aes_ctr_worker (selected by the AES round
 * count), and a trailing partial block is produced with ctr->cipher, with
 * its unused keystream buffered for the next call.
 *
 * Fails with SEC_ERROR_INPUT_LEN, SEC_ERROR_OUTPUT_LEN or
 * SEC_ERROR_INVALID_ARGS (counter wrapped to its initial value on the
 * partial-block path) exactly as CTR_Update does.
 */
SECStatus
CTR_Update_HW_AES(CTRContext *ctr, unsigned char *outbuf,
                  unsigned int *outlen, unsigned int maxout,
                  const unsigned char *inbuf, unsigned int inlen,
                  unsigned int blocksize)
{
    unsigned int fullblocks;
    unsigned int tmp;
    SECStatus rv;

    /* Limit block count to 2^counterBits - 2, computed in 64 bits to avoid
     * the signed-int shift/multiplication overflow for large counterBits. */
    if (ctr->counterBits < sizeof(unsigned int) * PR_BITS_PER_BYTE) {
        PRUint64 maxBlocks = ((PRUint64)1 << ctr->counterBits) - 2;
        if ((PRUint64)inlen > maxBlocks * AES_BLOCK_SIZE) {
            PORT_SetError(SEC_ERROR_INPUT_LEN);
            return SECFailure;
        }
    }
    if (maxout < inlen) {
        *outlen = inlen;
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
        return SECFailure;
    }
    *outlen = 0;

    /* Use up keystream bytes left over from the previous call. */
    if (ctr->bufPtr != blocksize) {
        unsigned int needed = PR_MIN(blocksize - ctr->bufPtr, inlen);
        ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed);
        ctr->bufPtr += needed;
        outbuf += needed;
        inbuf += needed;
        *outlen += needed;
        inlen -= needed;
        if (inlen == 0) {
            return SECSuccess;
        }
        PORT_Assert(ctr->bufPtr == blocksize);
    }

    /* Whole blocks go to the hardware worker.
     * NOTE(review): whether intel_aes_ctr_worker performs the checkWrap
     * test is not visible from this file — TODO confirm. */
    if (inlen >= blocksize) {
        rv = intel_aes_ctr_worker(((AESContext *)(ctr->context))->Nr)(
            ctr, outbuf, outlen, maxout, inbuf, inlen, blocksize);
        if (rv != SECSuccess) {
            return SECFailure;
        }
        fullblocks = (inlen / blocksize) * blocksize;
        *outlen += fullblocks;
        outbuf += fullblocks;
        inbuf += fullblocks;
        inlen -= fullblocks;
    }

    if (inlen == 0) {
        return SECSuccess;
    }

    /* Trailing partial block. */
    rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize,
                        ctr->counter, blocksize, blocksize);
    ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize);
    /* Same wrap detection as CTR_Update; it was missing on this path. */
    if (ctr->checkWrap) {
        if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) {
            PORT_SetError(SEC_ERROR_INVALID_ARGS);
            return SECFailure;
        }
    }
    if (rv != SECSuccess) {
        return SECFailure;
    }
    ctr_xor(outbuf, inbuf, ctr->buffer, inlen);
    ctr->bufPtr = inlen;
    *outlen += inlen;
    return SECSuccess;
}
#endif
277