1 /* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5 #ifdef FREEBL_NO_DEPEND
6 #include "stubs.h"
7 #endif
8 #include "prtypes.h"
9 #include "blapit.h"
10 #include "blapii.h"
11 #include "ctr.h"
12 #include "pkcs11t.h"
13 #include "secerr.h"
14
15 #ifdef USE_HW_AES
16 #ifdef NSS_X86_OR_X64
17 #include "intel-aes.h"
18 #endif
19 #include "rijndael.h"
20 #endif
21
22 #if defined(__ARM_NEON) || defined(__ARM_NEON__)
23 #include <arm_neon.h>
24 #endif
25
26 SECStatus
CTR_InitContext(CTRContext * ctr,void * context,freeblCipherFunc cipher,const unsigned char * param)27 CTR_InitContext(CTRContext *ctr, void *context, freeblCipherFunc cipher,
28 const unsigned char *param)
29 {
30 const CK_AES_CTR_PARAMS *ctrParams = (const CK_AES_CTR_PARAMS *)param;
31
32 if (ctrParams->ulCounterBits == 0 ||
33 ctrParams->ulCounterBits > AES_BLOCK_SIZE * PR_BITS_PER_BYTE) {
34 PORT_SetError(SEC_ERROR_INVALID_ARGS);
35 return SECFailure;
36 }
37
38 /* Invariant: 0 < ctr->bufPtr <= AES_BLOCK_SIZE */
39 ctr->checkWrap = PR_FALSE;
40 ctr->bufPtr = AES_BLOCK_SIZE; /* no unused data in the buffer */
41 ctr->cipher = cipher;
42 ctr->context = context;
43 ctr->counterBits = ctrParams->ulCounterBits;
44 if (AES_BLOCK_SIZE > sizeof(ctr->counter) ||
45 AES_BLOCK_SIZE > sizeof(ctrParams->cb)) {
46 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
47 return SECFailure;
48 }
49 PORT_Memcpy(ctr->counter, ctrParams->cb, AES_BLOCK_SIZE);
50 if (ctr->counterBits < 64) {
51 PORT_Memcpy(ctr->counterFirst, ctr->counter, AES_BLOCK_SIZE);
52 ctr->checkWrap = PR_TRUE;
53 }
54 return SECSuccess;
55 }
56
57 CTRContext *
CTR_CreateContext(void * context,freeblCipherFunc cipher,const unsigned char * param)58 CTR_CreateContext(void *context, freeblCipherFunc cipher,
59 const unsigned char *param)
60 {
61 CTRContext *ctr;
62 SECStatus rv;
63
64 /* first fill in the Counter context */
65 ctr = PORT_ZNew(CTRContext);
66 if (ctr == NULL) {
67 return NULL;
68 }
69 rv = CTR_InitContext(ctr, context, cipher, param);
70 if (rv != SECSuccess) {
71 CTR_DestroyContext(ctr, PR_TRUE);
72 ctr = NULL;
73 }
74 return ctr;
75 }
76
77 void
CTR_DestroyContext(CTRContext * ctr,PRBool freeit)78 CTR_DestroyContext(CTRContext *ctr, PRBool freeit)
79 {
80 PORT_Memset(ctr, 0, sizeof(CTRContext));
81 if (freeit) {
82 PORT_Free(ctr);
83 }
84 }
85
/*
 * Used by counter mode. Increment the counter block. Not all bits in the
 * counter block are part of the counter, counterBits tells how many bits
 * are part of the counter. The counter block is blocksize long. It's a
 * big endian value.
 *
 *  counter      the counter block, modified in place
 *  counterBits  width of the counter field, taken from the low-order
 *               (trailing) end of the block; must not exceed the block width
 *  blocksize    length of the counter block in bytes
 *
 * The counter field wraps to zero on overflow. Bits outside the counter
 * field are never modified, including when counterBits is not a multiple
 * of PR_BITS_PER_BYTE. Detecting the wrap is left to the caller.
 */
static void
ctr_GetNextCtr(unsigned char *counter, unsigned int counterBits,
               unsigned int blocksize)
{
    /* Index one past the next byte to touch; walking by index avoids ever
     * forming a pointer in front of the block when every byte wraps. */
    unsigned int pos = blocksize;
    unsigned char mask, count;

    PORT_Assert(counterBits <= blocksize * PR_BITS_PER_BYTE);

    /* Whole counter bytes: stop as soon as a byte does not wrap to zero. */
    while (counterBits >= PR_BITS_PER_BYTE) {
        if (++counter[--pos]) {
            return;
        }
        counterBits -= PR_BITS_PER_BYTE;
    }
    if (counterBits == 0) {
        return;
    }

    /*
     * Increment the final partial byte. The sum is formed separately and
     * masked before being merged back, so that a carry out of the counter
     * field is dropped instead of spilling into the neighbouring bits.
     */
    --pos;
    mask = (unsigned char)((1U << counterBits) - 1U);
    count = (unsigned char)((counter[pos] + 1U) & mask);
    counter[pos] = (unsigned char)((counter[pos] & (unsigned char)~mask) | count);
}
117
/*
 * Store x[i] ^ y[i] into target[i] for the first 'count' bytes.
 * On NEON-capable builds full 16-byte groups are handled with vector
 * loads/stores; the remaining bytes (or all of them elsewhere) are done
 * one at a time.
 */
static void
ctr_xor(unsigned char *target, const unsigned char *x,
        const unsigned char *y, unsigned int count)
{
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
    /* 16 bytes per step for as long as a full vector remains. */
    for (; count >= 16; count -= 16) {
        vst1q_u8(target, veorq_u8(vld1q_u8(x), vld1q_u8(y)));
        target += 16;
        x += 16;
        y += 16;
    }
#endif
    /* Scalar tail. */
    while (count > 0) {
        *target = *x ^ *y;
        target++;
        x++;
        y++;
        count--;
    }
}
136
137 SECStatus
CTR_Update(CTRContext * ctr,unsigned char * outbuf,unsigned int * outlen,unsigned int maxout,const unsigned char * inbuf,unsigned int inlen,unsigned int blocksize)138 CTR_Update(CTRContext *ctr, unsigned char *outbuf,
139 unsigned int *outlen, unsigned int maxout,
140 const unsigned char *inbuf, unsigned int inlen,
141 unsigned int blocksize)
142 {
143 unsigned int tmp;
144 SECStatus rv;
145
146 // Limit block count to 2^counterBits - 2
147 if (ctr->counterBits < (sizeof(unsigned int) * 8) &&
148 inlen > ((1 << ctr->counterBits) - 2) * AES_BLOCK_SIZE) {
149 PORT_SetError(SEC_ERROR_INPUT_LEN);
150 return SECFailure;
151 }
152 if (maxout < inlen) {
153 *outlen = inlen;
154 PORT_SetError(SEC_ERROR_OUTPUT_LEN);
155 return SECFailure;
156 }
157 *outlen = 0;
158 if (ctr->bufPtr != blocksize) {
159 unsigned int needed = PR_MIN(blocksize - ctr->bufPtr, inlen);
160 ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed);
161 ctr->bufPtr += needed;
162 outbuf += needed;
163 inbuf += needed;
164 *outlen += needed;
165 inlen -= needed;
166 if (inlen == 0) {
167 return SECSuccess;
168 }
169 PORT_Assert(ctr->bufPtr == blocksize);
170 }
171
172 while (inlen >= blocksize) {
173 rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize,
174 ctr->counter, blocksize, blocksize);
175 ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize);
176 if (ctr->checkWrap) {
177 if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) {
178 PORT_SetError(SEC_ERROR_INVALID_ARGS);
179 return SECFailure;
180 }
181 }
182 if (rv != SECSuccess) {
183 return SECFailure;
184 }
185 ctr_xor(outbuf, inbuf, ctr->buffer, blocksize);
186 outbuf += blocksize;
187 inbuf += blocksize;
188 *outlen += blocksize;
189 inlen -= blocksize;
190 }
191 if (inlen == 0) {
192 return SECSuccess;
193 }
194 rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize,
195 ctr->counter, blocksize, blocksize);
196 ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize);
197 if (ctr->checkWrap) {
198 if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) {
199 PORT_SetError(SEC_ERROR_INVALID_ARGS);
200 return SECFailure;
201 }
202 }
203 if (rv != SECSuccess) {
204 return SECFailure;
205 }
206 ctr_xor(outbuf, inbuf, ctr->buffer, inlen);
207 ctr->bufPtr = inlen;
208 *outlen += inlen;
209 return SECSuccess;
210 }
211
#if defined(USE_HW_AES) && defined(_MSC_VER) && defined(NSS_X86_OR_X64)
/*
 * Variant of CTR_Update that hands the run of full blocks to the routine
 * returned by intel_aes_ctr_worker (selected by the AESContext's Nr field).
 * Leftover key stream from a previous call and a trailing partial block are
 * handled here, the same way CTR_Update handles them.
 *
 * Parameters and error codes match CTR_Update.
 *
 * NOTE(review): the worker is presumed to advance ctr->counter for the
 * blocks it processes and to leave *outlen untouched (this function adds
 * the full-block byte count itself) -- confirm against intel-aes.h. The
 * worker's blocks are not covered by the counter wrap check below.
 */
SECStatus
CTR_Update_HW_AES(CTRContext *ctr, unsigned char *outbuf,
                  unsigned int *outlen, unsigned int maxout,
                  const unsigned char *inbuf, unsigned int inlen,
                  unsigned int blocksize)
{
    unsigned int fullblocks;
    unsigned int tmp;
    SECStatus rv;

    /*
     * Limit block count to 2^counterBits - 2.
     * The bound is evaluated in unsigned long long: as an int expression
     * both 1 << 31 and the multiplication by AES_BLOCK_SIZE overflow once
     * counterBits reaches 28, which is undefined behaviour.
     */
    if (ctr->counterBits < (sizeof(unsigned int) * 8) &&
        inlen > ((1ULL << ctr->counterBits) - 2) * AES_BLOCK_SIZE) {
        PORT_SetError(SEC_ERROR_INPUT_LEN);
        return SECFailure;
    }
    if (maxout < inlen) {
        *outlen = inlen; /* tell the caller how much room is required */
        PORT_SetError(SEC_ERROR_OUTPUT_LEN);
        return SECFailure;
    }
    *outlen = 0;

    /* Use up key stream left over from the previous call, if any. */
    if (ctr->bufPtr != blocksize) {
        unsigned int needed = PR_MIN(blocksize - ctr->bufPtr, inlen);
        ctr_xor(outbuf, inbuf, ctr->buffer + ctr->bufPtr, needed);
        ctr->bufPtr += needed;
        outbuf += needed;
        inbuf += needed;
        *outlen += needed;
        inlen -= needed;
        if (inlen == 0) {
            return SECSuccess;
        }
        PORT_Assert(ctr->bufPtr == blocksize);
    }

    /* Full blocks go to the hardware-assisted worker in one call. */
    if (inlen >= blocksize) {
        rv = intel_aes_ctr_worker(((AESContext *)(ctr->context))->Nr)(
            ctr, outbuf, outlen, maxout, inbuf, inlen, blocksize);
        if (rv != SECSuccess) {
            return SECFailure;
        }
        fullblocks = (inlen / blocksize) * blocksize;
        *outlen += fullblocks;
        outbuf += fullblocks;
        inbuf += fullblocks;
        inlen -= fullblocks;
    }

    if (inlen == 0) {
        return SECSuccess;
    }

    /* Trailing partial block: generate one more block of key stream and
     * keep the unused remainder in ctr->buffer for the next call. */
    rv = (*ctr->cipher)(ctr->context, ctr->buffer, &tmp, blocksize,
                        ctr->counter, blocksize, blocksize);
    ctr_GetNextCtr(ctr->counter, ctr->counterBits, blocksize);
    if (ctr->checkWrap) {
        /* Same wrap test CTR_Update applies after advancing the counter. */
        if (PORT_Memcmp(ctr->counter, ctr->counterFirst, blocksize) == 0) {
            PORT_SetError(SEC_ERROR_INVALID_ARGS);
            return SECFailure;
        }
    }
    if (rv != SECSuccess) {
        return SECFailure;
    }
    ctr_xor(outbuf, inbuf, ctr->buffer, inlen);
    ctr->bufPtr = inlen;
    *outlen += inlen;
    return SECSuccess;
}
#endif
277