1 /*	$NetBSD: cipher-ctr-mt.c,v 1.10 2019/01/27 02:08:33 pgoyette Exp $	*/
2 /*
3  * OpenSSH Multi-threaded AES-CTR Cipher
4  *
5  * Author: Benjamin Bennett <ben@psc.edu>
6  * Copyright (c) 2008 Pittsburgh Supercomputing Center. All rights reserved.
7  *
8  * Based on original OpenSSH AES-CTR cipher. Small portions remain unchanged,
9  * Copyright (c) 2003 Markus Friedl <markus@openbsd.org>
10  *
11  * Permission to use, copy, modify, and distribute this software for any
12  * purpose with or without fee is hereby granted, provided that the above
13  * copyright notice and this permission notice appear in all copies.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
16  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
17  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
18  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
19  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
20  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
21  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
22  */
23 #include "includes.h"
24 __RCSID("$NetBSD: cipher-ctr-mt.c,v 1.10 2019/01/27 02:08:33 pgoyette Exp $");
25 
26 #include <sys/types.h>
27 
28 #include <stdarg.h>
29 #include <string.h>
30 
31 #include <openssl/evp.h>
32 
33 #include "xmalloc.h"
34 #include "log.h"
35 
36 #ifndef USE_BUILTIN_RIJNDAEL
37 #include <openssl/aes.h>
38 #endif
39 
40 #include <pthread.h>
41 
42 /*-------------------- TUNABLES --------------------*/
43 /* Number of pregen threads to use */
44 #define CIPHER_THREADS	2
45 
46 /* Number of keystream queues */
47 #define NUMKQ		(CIPHER_THREADS + 2)
48 
49 /* Length of a keystream queue */
50 #define KQLEN		4096
51 
52 /* Processor cacheline length */
53 #define CACHELINE_LEN	64
54 
55 /* Collect thread stats and print at cancellation when in debug mode */
56 /* #define CIPHER_THREAD_STATS */
57 
58 /* Use single-byte XOR instead of 8-byte XOR */
59 /* #define CIPHER_BYTE_XOR */
60 /*-------------------- END TUNABLES --------------------*/
61 
62 #ifdef AES_CTR_MT
63 
64 
65 const EVP_CIPHER *evp_aes_ctr_mt(void);
66 
#ifdef CIPHER_THREAD_STATS
/*
 * Struct to collect per-thread stats
 */
struct thread_stats {
	u_int	fills;	/* keystream queues this thread filled */
	u_int	skips;	/* queues skipped because another thread claimed them */
	u_int	waits;	/* times blocked in pthread_cond_wait */
	u_int	drains;	/* queues drained (bumped by the consumer only) */
};

/*
 * Debug print the thread stats
 * Use with pthread_cleanup_push for displaying at thread cancellation
 * (drains is deliberately omitted: pregen threads never drain a queue;
 * the main thread's drains/waits are reported in ssh_aes_ctr_cleanup)
 */
static void
thread_loop_stats(void *x)
{
	struct thread_stats *s = x;

	debug("tid %lu - %u fills, %u skips, %u waits", pthread_self(),
			s->fills, s->skips, s->waits);
}

 #define STATS_STRUCT(s)	struct thread_stats s
 #define STATS_INIT(s)		{ memset(&s, 0, sizeof(s)); }
 #define STATS_FILL(s)		{ s.fills++; }
 #define STATS_SKIP(s)		{ s.skips++; }
 #define STATS_WAIT(s)		{ s.waits++; }
 #define STATS_DRAIN(s)		{ s.drains++; }
#else
 #define STATS_STRUCT(s)
 #define STATS_INIT(s)
 #define STATS_FILL(s)
 #define STATS_SKIP(s)
 #define STATS_WAIT(s)
 #define STATS_DRAIN(s)
#endif
105 
/* Keystream Queue state (stored in kq.qstate, guarded by kq.lock) */
enum {
	KQINIT,		/* q[0] only: awaiting the initial fill by thread 0 */
	KQEMPTY,	/* fully consumed; ready for a producer to claim */
	KQFILLING,	/* a producer is generating keystream (lock released) */
	KQFULL,		/* keystream ready; waiting for the consumer */
	KQDRAINING	/* consumer is reading blocks out of it */
};
114 
/*
 * Keystream Queue struct: a batch of pregenerated keystream blocks plus
 * the counter value from which this queue's NEXT refill starts.  The pad
 * members separate the bulk key data from the state/lock fields,
 * presumably to avoid false sharing between producer and consumer
 * cachelines (CACHELINE_LEN is a tunable above).
 */
struct kq {
	u_char		keys[KQLEN][AES_BLOCK_SIZE];	/* pregenerated keystream */
	u_char		ctr[AES_BLOCK_SIZE];	/* counter for the next refill */
	u_char		pad0[CACHELINE_LEN];
	volatile int	qstate;		/* KQ* state; transitions under lock */
	pthread_mutex_t	lock;
	pthread_cond_t	cond;		/* signalled on qstate transitions */
	u_char		pad1[CACHELINE_LEN];
};
125 
/* Context struct: per-cipher state stored as EVP app data */
struct ssh_aes_ctr_ctx
{
	struct kq	q[NUMKQ];	/* keystream queue ring */
	AES_KEY		aes_ctx;	/* key schedule; each thread copies it */
	STATS_STRUCT(stats);		/* main-thread (consumer) stats */
	u_char		aes_counter[AES_BLOCK_SIZE];	/* NOTE(review): not referenced in this file — verify before removing */
	pthread_t	tid[CIPHER_THREADS];	/* pregen thread ids */
	int		state;		/* HAVE_KEY | HAVE_IV init progress */
	int		qidx;		/* queue currently being drained */
	int		ridx;		/* next unread block within that queue */
};
138 
139 /* <friedl>
140  * increment counter 'ctr',
141  * the counter is of size 'len' bytes and stored in network-byte-order.
142  * (LSB at ctr[len-1], MSB at ctr[0])
143  */
static void
ssh_ctr_inc(u_char *ctr, u_int len)
{
	u_int pos = len;

	/*
	 * Walk from the least-significant byte toward the MSB; once a
	 * byte increments without wrapping to zero there is no carry
	 * left to propagate, so stop.
	 */
	while (pos-- > 0) {
		ctr[pos]++;
		if (ctr[pos] != 0)
			return;
	}
}
153 
154 /*
155  * Add num to counter 'ctr'
156  */
static void
ssh_ctr_add(u_char *ctr, uint32_t num, u_int len)
{
	int idx;
	uint16_t carry;

	/*
	 * Ripple-add num into the big-endian counter starting at the
	 * least-significant byte, bailing out as soon as neither an
	 * addend byte nor a carry remains to be folded in.
	 */
	carry = 0;
	for (idx = len - 1; idx >= 0; idx--) {
		if (num == 0 && carry == 0)
			break;
		carry = ctr[idx] + (num & 0xff) + carry;
		num >>= 8;
		ctr[idx] = carry & 0xff;
		carry >>= 8;
	}
}
170 
171 /*
172  * Threads may be cancelled in a pthread_cond_wait, we must free the mutex
173  */
174 static void
thread_loop_cleanup(void * x)175 thread_loop_cleanup(void *x)
176 {
177 	pthread_mutex_unlock((pthread_mutex_t *)x);
178 }
179 
180 /*
181  * The life of a pregen thread:
182  *    Find empty keystream queues and fill them using their counter.
183  *    When done, update counter for the next fill.
184  */
185 static void *
thread_loop(void * x)186 thread_loop(void *x)
187 {
188 	AES_KEY key;
189 	STATS_STRUCT(stats);
190 	struct ssh_aes_ctr_ctx *c = x;
191 	struct kq *q;
192 	int i;
193 	int qidx;
194 
195 	/* Threads stats on cancellation */
196 	STATS_INIT(stats);
197 #ifdef CIPHER_THREAD_STATS
198 	pthread_cleanup_push(thread_loop_stats, &stats);
199 #endif
200 
201 	/* Thread local copy of AES key */
202 	memcpy(&key, &c->aes_ctx, sizeof(key));
203 
204 	/*
205 	 * Handle the special case of startup, one thread must fill
206  	 * the first KQ then mark it as draining. Lock held throughout.
207  	 */
208 	if (pthread_equal(pthread_self(), c->tid[0])) {
209 		q = &c->q[0];
210 		pthread_mutex_lock(&q->lock);
211 		if (q->qstate == KQINIT) {
212 			for (i = 0; i < KQLEN; i++) {
213 				AES_encrypt(q->ctr, q->keys[i], &key);
214 				ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
215 			}
216 			ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
217 			q->qstate = KQDRAINING;
218 			STATS_FILL(stats);
219 			pthread_cond_broadcast(&q->cond);
220 		}
221 		pthread_mutex_unlock(&q->lock);
222 	}
223 	else
224 		STATS_SKIP(stats);
225 
226 	/*
227  	 * Normal case is to find empty queues and fill them, skipping over
228  	 * queues already filled by other threads and stopping to wait for
229  	 * a draining queue to become empty.
230  	 *
231  	 * Multiple threads may be waiting on a draining queue and awoken
232  	 * when empty.  The first thread to wake will mark it as filling,
233  	 * others will move on to fill, skip, or wait on the next queue.
234  	 */
235 	for (qidx = 1;; qidx = (qidx + 1) % NUMKQ) {
236 		/* Check if I was cancelled, also checked in cond_wait */
237 		pthread_testcancel();
238 
239 		/* Lock queue and block if its draining */
240 		q = &c->q[qidx];
241 		pthread_mutex_lock(&q->lock);
242 		pthread_cleanup_push(thread_loop_cleanup, &q->lock);
243 		while (q->qstate == KQDRAINING || q->qstate == KQINIT) {
244 			STATS_WAIT(stats);
245 			pthread_cond_wait(&q->cond, &q->lock);
246 		}
247 		pthread_cleanup_pop(0);
248 
249 		/* If filling or full, somebody else got it, skip */
250 		if (q->qstate != KQEMPTY) {
251 			pthread_mutex_unlock(&q->lock);
252 			STATS_SKIP(stats);
253 			continue;
254 		}
255 
256 		/*
257  		 * Empty, let's fill it.
258  		 * Queue lock is relinquished while we do this so others
259  		 * can see that it's being filled.
260  		 */
261 		q->qstate = KQFILLING;
262 		pthread_mutex_unlock(&q->lock);
263 		for (i = 0; i < KQLEN; i++) {
264 			AES_encrypt(q->ctr, q->keys[i], &key);
265 			ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
266 		}
267 
268 		/* Re-lock, mark full and signal consumer */
269 		pthread_mutex_lock(&q->lock);
270 		ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
271 		q->qstate = KQFULL;
272 		STATS_FILL(stats);
273 		pthread_cond_signal(&q->cond);
274 		pthread_mutex_unlock(&q->lock);
275 	}
276 
277 #ifdef CIPHER_THREAD_STATS
278 	/* Stats */
279 	pthread_cleanup_pop(1);
280 #endif
281 
282 	return NULL;
283 }
284 
285 static int
ssh_aes_ctr(EVP_CIPHER_CTX * ctx,u_char * dest,const u_char * src,u_int len)286 ssh_aes_ctr(EVP_CIPHER_CTX *ctx, u_char *dest, const u_char *src,
287     u_int len)
288 {
289 	struct ssh_aes_ctr_ctx *c;
290 	struct kq *q, *oldq;
291 	int ridx;
292 	u_char *buf;
293 
294 	if (len == 0)
295 		return (1);
296 	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL)
297 		return (0);
298 
299 	q = &c->q[c->qidx];
300 	ridx = c->ridx;
301 
302 	/* src already padded to block multiple */
303 	while (len > 0) {
304 		buf = q->keys[ridx];
305 
306 #ifdef CIPHER_BYTE_XOR
307 		dest[0] = src[0] ^ buf[0];
308 		dest[1] = src[1] ^ buf[1];
309 		dest[2] = src[2] ^ buf[2];
310 		dest[3] = src[3] ^ buf[3];
311 		dest[4] = src[4] ^ buf[4];
312 		dest[5] = src[5] ^ buf[5];
313 		dest[6] = src[6] ^ buf[6];
314 		dest[7] = src[7] ^ buf[7];
315 		dest[8] = src[8] ^ buf[8];
316 		dest[9] = src[9] ^ buf[9];
317 		dest[10] = src[10] ^ buf[10];
318 		dest[11] = src[11] ^ buf[11];
319 		dest[12] = src[12] ^ buf[12];
320 		dest[13] = src[13] ^ buf[13];
321 		dest[14] = src[14] ^ buf[14];
322 		dest[15] = src[15] ^ buf[15];
323 #else
324 		*(uint64_t *)dest = *(uint64_t *)src ^ *(uint64_t *)buf;
325 		*(uint64_t *)(dest + 8) = *(uint64_t *)(src + 8) ^
326 						*(uint64_t *)(buf + 8);
327 #endif
328 
329 		dest += 16;
330 		src += 16;
331 		len -= 16;
332 		ssh_ctr_inc(ctx->iv, AES_BLOCK_SIZE);
333 
334 		/* Increment read index, switch queues on rollover */
335 		if ((ridx = (ridx + 1) % KQLEN) == 0) {
336 			oldq = q;
337 
338 			/* Mark next queue draining, may need to wait */
339 			c->qidx = (c->qidx + 1) % NUMKQ;
340 			q = &c->q[c->qidx];
341 			pthread_mutex_lock(&q->lock);
342 			while (q->qstate != KQFULL) {
343 				STATS_WAIT(c->stats);
344 				pthread_cond_wait(&q->cond, &q->lock);
345 			}
346 			q->qstate = KQDRAINING;
347 			pthread_mutex_unlock(&q->lock);
348 
349 			/* Mark consumed queue empty and signal producers */
350 			pthread_mutex_lock(&oldq->lock);
351 			oldq->qstate = KQEMPTY;
352 			STATS_DRAIN(c->stats);
353 			pthread_cond_broadcast(&oldq->cond);
354 			pthread_mutex_unlock(&oldq->lock);
355 		}
356 	}
357 	c->ridx = ridx;
358 	return (1);
359 }
360 
361 #define HAVE_NONE       0
362 #define HAVE_KEY        1
363 #define HAVE_IV         2
/*
 * EVP init callback.  Because the cipher sets EVP_CIPH_ALWAYS_CALL_INIT,
 * this may be called multiple times, each time supplying only the key,
 * only the IV, or both; c->state records what has arrived so far.  Once
 * both are present the keystream queues are (re)seeded and the pregen
 * threads started.  Always returns 1.
 */
static int
ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv,
    int enc)
{
	struct ssh_aes_ctr_ctx *c;
	int i;

	/* First call: allocate context and queue locks, attach to ctx */
	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL) {
		c = xmalloc(sizeof(*c));

		c->state = HAVE_NONE;
		for (i = 0; i < NUMKQ; i++) {
			pthread_mutex_init(&c->q[i].lock, NULL);
			pthread_cond_init(&c->q[i].cond, NULL);
		}

		STATS_INIT(c->stats);

		EVP_CIPHER_CTX_set_app_data(ctx, c);
	}

	/* Already fully initialized: this is a re-key, so stop the old
	 * pregen threads before touching shared state */
	if (c->state == (HAVE_KEY | HAVE_IV)) {
		/* Cancel pregen threads */
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_cancel(c->tid[i]);
		for (i = 0; i < CIPHER_THREADS; i++)
			pthread_join(c->tid[i], NULL);
		/* Start over getting key & iv */
		c->state = HAVE_NONE;
	}

	if (key != NULL) {
		AES_set_encrypt_key(key, EVP_CIPHER_CTX_key_length(ctx) * 8,
		    &c->aes_ctx);
		c->state |= HAVE_KEY;
	}

	if (iv != NULL) {
		memcpy(ctx->iv, iv, AES_BLOCK_SIZE);
		c->state |= HAVE_IV;
	}

	if (c->state == (HAVE_KEY | HAVE_IV)) {
		/* Reset queues; queue i's counter starts i*KQLEN blocks
		 * into the keystream so the queues interleave cleanly */
		memcpy(c->q[0].ctr, ctx->iv, AES_BLOCK_SIZE);
		c->q[0].qstate = KQINIT;
		for (i = 1; i < NUMKQ; i++) {
			memcpy(c->q[i].ctr, ctx->iv, AES_BLOCK_SIZE);
			ssh_ctr_add(c->q[i].ctr, i * KQLEN, AES_BLOCK_SIZE);
			c->q[i].qstate = KQEMPTY;
		}
		c->qidx = 0;
		c->ridx = 0;

		/* Start threads */
		for (i = 0; i < CIPHER_THREADS; i++) {
			pthread_create(&c->tid[i], NULL, thread_loop, c);
		}
		/* Block until thread 0 has primed q[0] and marked it
		 * draining, so the first ssh_aes_ctr call has keystream */
		pthread_mutex_lock(&c->q[0].lock);
		while (c->q[0].qstate != KQDRAINING)
			pthread_cond_wait(&c->q[0].cond, &c->q[0].lock);
		pthread_mutex_unlock(&c->q[0].lock);

	}
	return (1);
}
430 
431 static int
ssh_aes_ctr_cleanup(EVP_CIPHER_CTX * ctx)432 ssh_aes_ctr_cleanup(EVP_CIPHER_CTX *ctx)
433 {
434 	struct ssh_aes_ctr_ctx *c;
435 	int i;
436 
437 	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) != NULL) {
438 #ifdef CIPHER_THREAD_STATS
439 		debug("main thread: %u drains, %u waits", c->stats.drains,
440 				c->stats.waits);
441 #endif
442 		/* Cancel pregen threads */
443 		for (i = 0; i < CIPHER_THREADS; i++)
444 			pthread_cancel(c->tid[i]);
445 		for (i = 0; i < CIPHER_THREADS; i++)
446 			pthread_join(c->tid[i], NULL);
447 
448 		memset(c, 0, sizeof(*c));
449 		free(c);
450 		EVP_CIPHER_CTX_set_app_data(ctx, NULL);
451 	}
452 	return (1);
453 }
454 
455 /* <friedl> */
456 const EVP_CIPHER *
evp_aes_ctr_mt(void)457 evp_aes_ctr_mt(void)
458 {
459 	static EVP_CIPHER aes_ctr;
460 
461 	memset(&aes_ctr, 0, sizeof(EVP_CIPHER));
462 	aes_ctr.nid = NID_undef;
463 	aes_ctr.block_size = AES_BLOCK_SIZE;
464 	aes_ctr.iv_len = AES_BLOCK_SIZE;
465 	aes_ctr.key_len = 16;
466 	aes_ctr.init = ssh_aes_ctr_init;
467 	aes_ctr.cleanup = ssh_aes_ctr_cleanup;
468 	aes_ctr.do_cipher = ssh_aes_ctr;
469 #ifndef SSH_OLD_EVP
470 	aes_ctr.flags = EVP_CIPH_CBC_MODE | EVP_CIPH_VARIABLE_LENGTH |
471 	    EVP_CIPH_ALWAYS_CALL_INIT | EVP_CIPH_CUSTOM_IV;
472 #endif
473 	return (&aes_ctr);
474 }
475 #endif /* AES_CTR_MT */
476