1 /*	$NetBSD: cipher-ctr-mt.c,v 1.8 2017/04/18 18:41:46 christos Exp $	*/
2 /*
3  * OpenSSH Multi-threaded AES-CTR Cipher
4  *
5  * Author: Benjamin Bennett <ben@psc.edu>
6  * Copyright (c) 2008 Pittsburgh Supercomputing Center. All rights reserved.
7  *
8  * Based on original OpenSSH AES-CTR cipher. Small portions remain unchanged,
9  * Copyright (c) 2003 Markus Friedl <markus@openbsd.org>
10  *
11  * Permission to use, copy, modify, and distribute this software for any
12  * purpose with or without fee is hereby granted, provided that the above
13  * copyright notice and this permission notice appear in all copies.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
16  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
17  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
18  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
19  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
20  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
21  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
22  */
23 #include "includes.h"
24 
25 #include <sys/types.h>
26 
27 #include <stdarg.h>
28 #include <string.h>
29 
30 #include <openssl/evp.h>
31 
32 #include "xmalloc.h"
33 #include "log.h"
34 
35 #ifndef USE_BUILTIN_RIJNDAEL
36 #include <openssl/aes.h>
37 #endif
38 
39 #include <pthread.h>
40 
41 /*-------------------- TUNABLES --------------------*/
42 /* Number of pregen threads to use */
43 #define CIPHER_THREADS	2
44 
45 /* Number of keystream queues */
46 #define NUMKQ		(CIPHER_THREADS + 2)
47 
48 /* Length of a keystream queue */
49 #define KQLEN		4096
50 
51 /* Processor cacheline length */
52 #define CACHELINE_LEN	64
53 
54 /* Collect thread stats and print at cancellation when in debug mode */
55 /* #define CIPHER_THREAD_STATS */
56 
57 /* Use single-byte XOR instead of 8-byte XOR */
58 /* #define CIPHER_BYTE_XOR */
59 /*-------------------- END TUNABLES --------------------*/
60 
61 #ifdef AES_CTR_MT
62 
63 
64 const EVP_CIPHER *evp_aes_ctr_mt(void);
65 
66 #ifdef CIPHER_THREAD_STATS
67 /*
68  * Struct to collect thread stats
69  */
70 struct thread_stats {
71 	u_int	fills;
72 	u_int	skips;
73 	u_int	waits;
74 	u_int	drains;
75 };
76 
77 /*
78  * Debug print the thread stats
79  * Use with pthread_cleanup_push for displaying at thread cancellation
80  */
81 static void
82 thread_loop_stats(void *x)
83 {
84 	struct thread_stats *s = x;
85 
86 	debug("tid %lu - %u fills, %u skips, %u waits", pthread_self(),
87 			s->fills, s->skips, s->waits);
88 }
89 
90  #define STATS_STRUCT(s)	struct thread_stats s
91  #define STATS_INIT(s)		{ memset(&s, 0, sizeof(s)); }
92  #define STATS_FILL(s)		{ s.fills++; }
93  #define STATS_SKIP(s)		{ s.skips++; }
94  #define STATS_WAIT(s)		{ s.waits++; }
95  #define STATS_DRAIN(s)		{ s.drains++; }
96 #else
97  #define STATS_STRUCT(s)
98  #define STATS_INIT(s)
99  #define STATS_FILL(s)
100  #define STATS_SKIP(s)
101  #define STATS_WAIT(s)
102  #define STATS_DRAIN(s)
103 #endif
104 
105 /* Keystream Queue state */
106 enum {
107 	KQINIT,
108 	KQEMPTY,
109 	KQFILLING,
110 	KQFULL,
111 	KQDRAINING
112 };
113 
114 /* Keystream Queue struct */
115 struct kq {
116 	u_char		keys[KQLEN][AES_BLOCK_SIZE];
117 	u_char		ctr[AES_BLOCK_SIZE];
118 	u_char		pad0[CACHELINE_LEN];
119 	volatile int	qstate;
120 	pthread_mutex_t	lock;
121 	pthread_cond_t	cond;
122 	u_char		pad1[CACHELINE_LEN];
123 };
124 
125 /* Context struct */
126 struct ssh_aes_ctr_ctx
127 {
128 	struct kq	q[NUMKQ];
129 	AES_KEY		aes_ctx;
130 	STATS_STRUCT(stats);
131 	u_char		aes_counter[AES_BLOCK_SIZE];
132 	pthread_t	tid[CIPHER_THREADS];
133 	int		state;
134 	int		qidx;
135 	int		ridx;
136 };
137 
138 /* <friedl>
139  * increment counter 'ctr',
140  * the counter is of size 'len' bytes and stored in network-byte-order.
141  * (LSB at ctr[len-1], MSB at ctr[0])
142  */
143 static void
144 ssh_ctr_inc(u_char *ctr, u_int len)
145 {
146 	int i;
147 
148 	for (i = len - 1; i >= 0; i--)
149 		if (++ctr[i])	/* continue on overflow */
150 			return;
151 }
152 
153 /*
154  * Add num to counter 'ctr'
155  */
156 static void
157 ssh_ctr_add(u_char *ctr, uint32_t num, u_int len)
158 {
159 	int i;
160 	uint16_t n;
161 
162 	for (n = 0, i = len - 1; i >= 0 && (num || n); i--) {
163 		n = ctr[i] + (num & 0xff) + n;
164 		num >>= 8;
165 		ctr[i] = n & 0xff;
166 		n >>= 8;
167 	}
168 }
169 
170 /*
171  * Threads may be cancelled in a pthread_cond_wait, we must free the mutex
172  */
173 static void
174 thread_loop_cleanup(void *x)
175 {
176 	pthread_mutex_unlock((pthread_mutex_t *)x);
177 }
178 
179 /*
180  * The life of a pregen thread:
181  *    Find empty keystream queues and fill them using their counter.
182  *    When done, update counter for the next fill.
183  */
184 static void *
185 thread_loop(void *x)
186 {
187 	AES_KEY key;
188 	STATS_STRUCT(stats);
189 	struct ssh_aes_ctr_ctx *c = x;
190 	struct kq *q;
191 	int i;
192 	int qidx;
193 
194 	/* Threads stats on cancellation */
195 	STATS_INIT(stats);
196 #ifdef CIPHER_THREAD_STATS
197 	pthread_cleanup_push(thread_loop_stats, &stats);
198 #endif
199 
200 	/* Thread local copy of AES key */
201 	memcpy(&key, &c->aes_ctx, sizeof(key));
202 
203 	/*
204 	 * Handle the special case of startup, one thread must fill
205  	 * the first KQ then mark it as draining. Lock held throughout.
206  	 */
207 	if (pthread_equal(pthread_self(), c->tid[0])) {
208 		q = &c->q[0];
209 		pthread_mutex_lock(&q->lock);
210 		if (q->qstate == KQINIT) {
211 			for (i = 0; i < KQLEN; i++) {
212 				AES_encrypt(q->ctr, q->keys[i], &key);
213 				ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
214 			}
215 			ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
216 			q->qstate = KQDRAINING;
217 			STATS_FILL(stats);
218 			pthread_cond_broadcast(&q->cond);
219 		}
220 		pthread_mutex_unlock(&q->lock);
221 	}
222 	else
223 		STATS_SKIP(stats);
224 
225 	/*
226  	 * Normal case is to find empty queues and fill them, skipping over
227  	 * queues already filled by other threads and stopping to wait for
228  	 * a draining queue to become empty.
229  	 *
230  	 * Multiple threads may be waiting on a draining queue and awoken
231  	 * when empty.  The first thread to wake will mark it as filling,
232  	 * others will move on to fill, skip, or wait on the next queue.
233  	 */
234 	for (qidx = 1;; qidx = (qidx + 1) % NUMKQ) {
235 		/* Check if I was cancelled, also checked in cond_wait */
236 		pthread_testcancel();
237 
238 		/* Lock queue and block if its draining */
239 		q = &c->q[qidx];
240 		pthread_mutex_lock(&q->lock);
241 		pthread_cleanup_push(thread_loop_cleanup, &q->lock);
242 		while (q->qstate == KQDRAINING || q->qstate == KQINIT) {
243 			STATS_WAIT(stats);
244 			pthread_cond_wait(&q->cond, &q->lock);
245 		}
246 		pthread_cleanup_pop(0);
247 
248 		/* If filling or full, somebody else got it, skip */
249 		if (q->qstate != KQEMPTY) {
250 			pthread_mutex_unlock(&q->lock);
251 			STATS_SKIP(stats);
252 			continue;
253 		}
254 
255 		/*
256  		 * Empty, let's fill it.
257  		 * Queue lock is relinquished while we do this so others
258  		 * can see that it's being filled.
259  		 */
260 		q->qstate = KQFILLING;
261 		pthread_mutex_unlock(&q->lock);
262 		for (i = 0; i < KQLEN; i++) {
263 			AES_encrypt(q->ctr, q->keys[i], &key);
264 			ssh_ctr_inc(q->ctr, AES_BLOCK_SIZE);
265 		}
266 
267 		/* Re-lock, mark full and signal consumer */
268 		pthread_mutex_lock(&q->lock);
269 		ssh_ctr_add(q->ctr, KQLEN * (NUMKQ - 1), AES_BLOCK_SIZE);
270 		q->qstate = KQFULL;
271 		STATS_FILL(stats);
272 		pthread_cond_signal(&q->cond);
273 		pthread_mutex_unlock(&q->lock);
274 	}
275 
276 #ifdef CIPHER_THREAD_STATS
277 	/* Stats */
278 	pthread_cleanup_pop(1);
279 #endif
280 
281 	return NULL;
282 }
283 
284 static int
285 ssh_aes_ctr(EVP_CIPHER_CTX *ctx, u_char *dest, const u_char *src,
286     u_int len)
287 {
288 	struct ssh_aes_ctr_ctx *c;
289 	struct kq *q, *oldq;
290 	int ridx;
291 	u_char *buf;
292 
293 	if (len == 0)
294 		return (1);
295 	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL)
296 		return (0);
297 
298 	q = &c->q[c->qidx];
299 	ridx = c->ridx;
300 
301 	/* src already padded to block multiple */
302 	while (len > 0) {
303 		buf = q->keys[ridx];
304 
305 #ifdef CIPHER_BYTE_XOR
306 		dest[0] = src[0] ^ buf[0];
307 		dest[1] = src[1] ^ buf[1];
308 		dest[2] = src[2] ^ buf[2];
309 		dest[3] = src[3] ^ buf[3];
310 		dest[4] = src[4] ^ buf[4];
311 		dest[5] = src[5] ^ buf[5];
312 		dest[6] = src[6] ^ buf[6];
313 		dest[7] = src[7] ^ buf[7];
314 		dest[8] = src[8] ^ buf[8];
315 		dest[9] = src[9] ^ buf[9];
316 		dest[10] = src[10] ^ buf[10];
317 		dest[11] = src[11] ^ buf[11];
318 		dest[12] = src[12] ^ buf[12];
319 		dest[13] = src[13] ^ buf[13];
320 		dest[14] = src[14] ^ buf[14];
321 		dest[15] = src[15] ^ buf[15];
322 #else
323 		*(uint64_t *)dest = *(uint64_t *)src ^ *(uint64_t *)buf;
324 		*(uint64_t *)(dest + 8) = *(uint64_t *)(src + 8) ^
325 						*(uint64_t *)(buf + 8);
326 #endif
327 
328 		dest += 16;
329 		src += 16;
330 		len -= 16;
331 		ssh_ctr_inc(ctx->iv, AES_BLOCK_SIZE);
332 
333 		/* Increment read index, switch queues on rollover */
334 		if ((ridx = (ridx + 1) % KQLEN) == 0) {
335 			oldq = q;
336 
337 			/* Mark next queue draining, may need to wait */
338 			c->qidx = (c->qidx + 1) % NUMKQ;
339 			q = &c->q[c->qidx];
340 			pthread_mutex_lock(&q->lock);
341 			while (q->qstate != KQFULL) {
342 				STATS_WAIT(c->stats);
343 				pthread_cond_wait(&q->cond, &q->lock);
344 			}
345 			q->qstate = KQDRAINING;
346 			pthread_mutex_unlock(&q->lock);
347 
348 			/* Mark consumed queue empty and signal producers */
349 			pthread_mutex_lock(&oldq->lock);
350 			oldq->qstate = KQEMPTY;
351 			STATS_DRAIN(c->stats);
352 			pthread_cond_broadcast(&oldq->cond);
353 			pthread_mutex_unlock(&oldq->lock);
354 		}
355 	}
356 	c->ridx = ridx;
357 	return (1);
358 }
359 
360 #define HAVE_NONE       0
361 #define HAVE_KEY        1
362 #define HAVE_IV         2
363 static int
364 ssh_aes_ctr_init(EVP_CIPHER_CTX *ctx, const u_char *key, const u_char *iv,
365     int enc)
366 {
367 	struct ssh_aes_ctr_ctx *c;
368 	int i;
369 
370 	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) == NULL) {
371 		c = xmalloc(sizeof(*c));
372 
373 		c->state = HAVE_NONE;
374 		for (i = 0; i < NUMKQ; i++) {
375 			pthread_mutex_init(&c->q[i].lock, NULL);
376 			pthread_cond_init(&c->q[i].cond, NULL);
377 		}
378 
379 		STATS_INIT(c->stats);
380 
381 		EVP_CIPHER_CTX_set_app_data(ctx, c);
382 	}
383 
384 	if (c->state == (HAVE_KEY | HAVE_IV)) {
385 		/* Cancel pregen threads */
386 		for (i = 0; i < CIPHER_THREADS; i++)
387 			pthread_cancel(c->tid[i]);
388 		for (i = 0; i < CIPHER_THREADS; i++)
389 			pthread_join(c->tid[i], NULL);
390 		/* Start over getting key & iv */
391 		c->state = HAVE_NONE;
392 	}
393 
394 	if (key != NULL) {
395 		AES_set_encrypt_key(key, EVP_CIPHER_CTX_key_length(ctx) * 8,
396 		    &c->aes_ctx);
397 		c->state |= HAVE_KEY;
398 	}
399 
400 	if (iv != NULL) {
401 		memcpy(ctx->iv, iv, AES_BLOCK_SIZE);
402 		c->state |= HAVE_IV;
403 	}
404 
405 	if (c->state == (HAVE_KEY | HAVE_IV)) {
406 		/* Clear queues */
407 		memcpy(c->q[0].ctr, ctx->iv, AES_BLOCK_SIZE);
408 		c->q[0].qstate = KQINIT;
409 		for (i = 1; i < NUMKQ; i++) {
410 			memcpy(c->q[i].ctr, ctx->iv, AES_BLOCK_SIZE);
411 			ssh_ctr_add(c->q[i].ctr, i * KQLEN, AES_BLOCK_SIZE);
412 			c->q[i].qstate = KQEMPTY;
413 		}
414 		c->qidx = 0;
415 		c->ridx = 0;
416 
417 		/* Start threads */
418 		for (i = 0; i < CIPHER_THREADS; i++) {
419 			pthread_create(&c->tid[i], NULL, thread_loop, c);
420 		}
421 		pthread_mutex_lock(&c->q[0].lock);
422 		while (c->q[0].qstate != KQDRAINING)
423 			pthread_cond_wait(&c->q[0].cond, &c->q[0].lock);
424 		pthread_mutex_unlock(&c->q[0].lock);
425 
426 	}
427 	return (1);
428 }
429 
430 static int
431 ssh_aes_ctr_cleanup(EVP_CIPHER_CTX *ctx)
432 {
433 	struct ssh_aes_ctr_ctx *c;
434 	int i;
435 
436 	if ((c = EVP_CIPHER_CTX_get_app_data(ctx)) != NULL) {
437 #ifdef CIPHER_THREAD_STATS
438 		debug("main thread: %u drains, %u waits", c->stats.drains,
439 				c->stats.waits);
440 #endif
441 		/* Cancel pregen threads */
442 		for (i = 0; i < CIPHER_THREADS; i++)
443 			pthread_cancel(c->tid[i]);
444 		for (i = 0; i < CIPHER_THREADS; i++)
445 			pthread_join(c->tid[i], NULL);
446 
447 		memset(c, 0, sizeof(*c));
448 		free(c);
449 		EVP_CIPHER_CTX_set_app_data(ctx, NULL);
450 	}
451 	return (1);
452 }
453 
454 /* <friedl> */
455 const EVP_CIPHER *
456 evp_aes_ctr_mt(void)
457 {
458 	static EVP_CIPHER aes_ctr;
459 
460 	memset(&aes_ctr, 0, sizeof(EVP_CIPHER));
461 	aes_ctr.nid = NID_undef;
462 	aes_ctr.block_size = AES_BLOCK_SIZE;
463 	aes_ctr.iv_len = AES_BLOCK_SIZE;
464 	aes_ctr.key_len = 16;
465 	aes_ctr.init = ssh_aes_ctr_init;
466 	aes_ctr.cleanup = ssh_aes_ctr_cleanup;
467 	aes_ctr.do_cipher = ssh_aes_ctr;
468 #ifndef SSH_OLD_EVP
469 	aes_ctr.flags = EVP_CIPH_CBC_MODE | EVP_CIPH_VARIABLE_LENGTH |
470 	    EVP_CIPH_ALWAYS_CALL_INIT | EVP_CIPH_CUSTOM_IV;
471 #endif
472 	return (&aes_ctr);
473 }
474 #endif /* AES_CTR_MT */
475