1 /*
2  * Copyright (C) 2012 Michael Brown <mbrown@fensystems.co.uk>.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU General Public License as
6  * published by the Free Software Foundation; either version 2 of the
7  * License, or any later version.
8  *
9  * This program is distributed in the hope that it will be useful, but
10  * WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12  * General Public License for more details.
13  *
14  * You should have received a copy of the GNU General Public License
15  * along with this program; if not, write to the Free Software
16  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
17  * 02110-1301, USA.
18  *
19  * You can also choose to distribute this program under the terms of
20  * the Unmodified Binary Distribution Licence (as given in the file
21  * COPYING.UBDL), provided that you have satisfied its requirements.
22  */
23 
24 FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
25 
26 #include <stdint.h>
27 #include <stdlib.h>
28 #include <stdarg.h>
29 #include <string.h>
30 #include <errno.h>
31 #include <ipxe/asn1.h>
32 #include <ipxe/crypto.h>
33 #include <ipxe/bigint.h>
34 #include <ipxe/random_nz.h>
35 #include <ipxe/rsa.h>
36 
37 /** @file
38  *
39  * RSA public-key cryptography
40  *
41  * RSA is documented in RFC 3447.
42  */
43 
/* Disambiguated error code returned when the deciphered signature does
 * not match the locally encoded digest.
 */
#define EINFO_EACCES_VERIFY \
	__einfo_uniqify ( EINFO_EACCES, 0x01, "RSA signature incorrect" )
#define EACCES_VERIFY __einfo_error ( EINFO_EACCES_VERIFY )
49 
50 /**
51  * Identify RSA prefix
52  *
53  * @v digest		Digest algorithm
54  * @ret prefix		RSA prefix, or NULL
55  */
56 static struct rsa_digestinfo_prefix *
rsa_find_prefix(struct digest_algorithm * digest)57 rsa_find_prefix ( struct digest_algorithm *digest ) {
58 	struct rsa_digestinfo_prefix *prefix;
59 
60 	for_each_table_entry ( prefix, RSA_DIGESTINFO_PREFIXES ) {
61 		if ( prefix->digest == digest )
62 			return prefix;
63 	}
64 	return NULL;
65 }
66 
67 /**
68  * Free RSA dynamic storage
69  *
70  * @v context		RSA context
71  */
rsa_free(struct rsa_context * context)72 static void rsa_free ( struct rsa_context *context ) {
73 
74 	free ( context->dynamic );
75 	context->dynamic = NULL;
76 }
77 
78 /**
79  * Allocate RSA dynamic storage
80  *
81  * @v context		RSA context
82  * @v modulus_len	Modulus length
83  * @v exponent_len	Exponent length
84  * @ret rc		Return status code
85  */
rsa_alloc(struct rsa_context * context,size_t modulus_len,size_t exponent_len)86 static int rsa_alloc ( struct rsa_context *context, size_t modulus_len,
87 		       size_t exponent_len ) {
88 	unsigned int size = bigint_required_size ( modulus_len );
89 	unsigned int exponent_size = bigint_required_size ( exponent_len );
90 	bigint_t ( size ) *modulus;
91 	bigint_t ( exponent_size ) *exponent;
92 	size_t tmp_len = bigint_mod_exp_tmp_len ( modulus, exponent );
93 	struct {
94 		bigint_t ( size ) modulus;
95 		bigint_t ( exponent_size ) exponent;
96 		bigint_t ( size ) input;
97 		bigint_t ( size ) output;
98 		uint8_t tmp[tmp_len];
99 	} __attribute__ (( packed )) *dynamic;
100 
101 	/* Free any existing dynamic storage */
102 	rsa_free ( context );
103 
104 	/* Allocate dynamic storage */
105 	dynamic = malloc ( sizeof ( *dynamic ) );
106 	if ( ! dynamic )
107 		return -ENOMEM;
108 
109 	/* Assign dynamic storage */
110 	context->dynamic = dynamic;
111 	context->modulus0 = &dynamic->modulus.element[0];
112 	context->size = size;
113 	context->max_len = modulus_len;
114 	context->exponent0 = &dynamic->exponent.element[0];
115 	context->exponent_size = exponent_size;
116 	context->input0 = &dynamic->input.element[0];
117 	context->output0 = &dynamic->output.element[0];
118 	context->tmp = &dynamic->tmp;
119 
120 	return 0;
121 }
122 
123 /**
124  * Parse RSA integer
125  *
126  * @v integer		Integer to fill in
127  * @v raw		ASN.1 cursor
128  * @ret rc		Return status code
129  */
rsa_parse_integer(struct asn1_cursor * integer,const struct asn1_cursor * raw)130 static int rsa_parse_integer ( struct asn1_cursor *integer,
131 			       const struct asn1_cursor *raw ) {
132 
133 	/* Enter integer */
134 	memcpy ( integer, raw, sizeof ( *integer ) );
135 	asn1_enter ( integer, ASN1_INTEGER );
136 
137 	/* Skip initial sign byte if applicable */
138 	if ( ( integer->len > 1 ) &&
139 	     ( *( ( uint8_t * ) integer->data ) == 0x00 ) ) {
140 		integer->data++;
141 		integer->len--;
142 	}
143 
144 	/* Fail if cursor or integer are invalid */
145 	if ( ! integer->len )
146 		return -EINVAL;
147 
148 	return 0;
149 }
150 
151 /**
152  * Parse RSA modulus and exponent
153  *
154  * @v modulus		Modulus to fill in
155  * @v exponent		Exponent to fill in
156  * @v raw		ASN.1 cursor
157  * @ret rc		Return status code
158  */
rsa_parse_mod_exp(struct asn1_cursor * modulus,struct asn1_cursor * exponent,const struct asn1_cursor * raw)159 static int rsa_parse_mod_exp ( struct asn1_cursor *modulus,
160 			       struct asn1_cursor *exponent,
161 			       const struct asn1_cursor *raw ) {
162 	struct asn1_bit_string bits;
163 	struct asn1_cursor cursor;
164 	int is_private;
165 	int rc;
166 
167 	/* Enter subjectPublicKeyInfo/RSAPrivateKey */
168 	memcpy ( &cursor, raw, sizeof ( cursor ) );
169 	asn1_enter ( &cursor, ASN1_SEQUENCE );
170 
171 	/* Determine key format */
172 	if ( asn1_type ( &cursor ) == ASN1_INTEGER ) {
173 
174 		/* Private key */
175 		is_private = 1;
176 
177 		/* Skip version */
178 		asn1_skip_any ( &cursor );
179 
180 	} else {
181 
182 		/* Public key */
183 		is_private = 0;
184 
185 		/* Skip algorithm */
186 		asn1_skip ( &cursor, ASN1_SEQUENCE );
187 
188 		/* Enter subjectPublicKey */
189 		if ( ( rc = asn1_integral_bit_string ( &cursor, &bits ) ) != 0 )
190 			return rc;
191 		cursor.data = bits.data;
192 		cursor.len = bits.len;
193 
194 		/* Enter RSAPublicKey */
195 		asn1_enter ( &cursor, ASN1_SEQUENCE );
196 	}
197 
198 	/* Extract modulus */
199 	if ( ( rc = rsa_parse_integer ( modulus, &cursor ) ) != 0 )
200 		return rc;
201 	asn1_skip_any ( &cursor );
202 
203 	/* Skip public exponent, if applicable */
204 	if ( is_private )
205 		asn1_skip ( &cursor, ASN1_INTEGER );
206 
207 	/* Extract publicExponent/privateExponent */
208 	if ( ( rc = rsa_parse_integer ( exponent, &cursor ) ) != 0 )
209 		return rc;
210 
211 	return 0;
212 }
213 
214 /**
215  * Initialise RSA cipher
216  *
217  * @v ctx		RSA context
218  * @v key		Key
219  * @v key_len		Length of key
220  * @ret rc		Return status code
221  */
rsa_init(void * ctx,const void * key,size_t key_len)222 static int rsa_init ( void *ctx, const void *key, size_t key_len ) {
223 	struct rsa_context *context = ctx;
224 	struct asn1_cursor modulus;
225 	struct asn1_cursor exponent;
226 	struct asn1_cursor cursor;
227 	int rc;
228 
229 	/* Initialise context */
230 	memset ( context, 0, sizeof ( *context ) );
231 
232 	/* Initialise cursor */
233 	cursor.data = key;
234 	cursor.len = key_len;
235 
236 	/* Parse modulus and exponent */
237 	if ( ( rc = rsa_parse_mod_exp ( &modulus, &exponent, &cursor ) ) != 0 ){
238 		DBGC ( context, "RSA %p invalid modulus/exponent:\n", context );
239 		DBGC_HDA ( context, 0, cursor.data, cursor.len );
240 		goto err_parse;
241 	}
242 
243 	DBGC ( context, "RSA %p modulus:\n", context );
244 	DBGC_HDA ( context, 0, modulus.data, modulus.len );
245 	DBGC ( context, "RSA %p exponent:\n", context );
246 	DBGC_HDA ( context, 0, exponent.data, exponent.len );
247 
248 	/* Allocate dynamic storage */
249 	if ( ( rc = rsa_alloc ( context, modulus.len, exponent.len ) ) != 0 )
250 		goto err_alloc;
251 
252 	/* Construct big integers */
253 	bigint_init ( ( ( bigint_t ( context->size ) * ) context->modulus0 ),
254 		      modulus.data, modulus.len );
255 	bigint_init ( ( ( bigint_t ( context->exponent_size ) * )
256 			context->exponent0 ), exponent.data, exponent.len );
257 
258 	return 0;
259 
260 	rsa_free ( context );
261  err_alloc:
262  err_parse:
263 	return rc;
264 }
265 
266 /**
267  * Calculate RSA maximum output length
268  *
269  * @v ctx		RSA context
270  * @ret max_len		Maximum output length
271  */
rsa_max_len(void * ctx)272 static size_t rsa_max_len ( void *ctx ) {
273 	struct rsa_context *context = ctx;
274 
275 	return context->max_len;
276 }
277 
278 /**
279  * Perform RSA cipher operation
280  *
281  * @v context		RSA context
282  * @v in		Input buffer
283  * @v out		Output buffer
284  */
rsa_cipher(struct rsa_context * context,const void * in,void * out)285 static void rsa_cipher ( struct rsa_context *context,
286 			 const void *in, void *out ) {
287 	bigint_t ( context->size ) *input = ( ( void * ) context->input0 );
288 	bigint_t ( context->size ) *output = ( ( void * ) context->output0 );
289 	bigint_t ( context->size ) *modulus = ( ( void * ) context->modulus0 );
290 	bigint_t ( context->exponent_size ) *exponent =
291 		( ( void * ) context->exponent0 );
292 
293 	/* Initialise big integer */
294 	bigint_init ( input, in, context->max_len );
295 
296 	/* Perform modular exponentiation */
297 	bigint_mod_exp ( input, modulus, exponent, output, context->tmp );
298 
299 	/* Copy out result */
300 	bigint_done ( output, out, context->max_len );
301 }
302 
303 /**
304  * Encrypt using RSA
305  *
306  * @v ctx		RSA context
307  * @v plaintext		Plaintext
308  * @v plaintext_len	Length of plaintext
309  * @v ciphertext	Ciphertext
310  * @ret ciphertext_len	Length of ciphertext, or negative error
311  */
rsa_encrypt(void * ctx,const void * plaintext,size_t plaintext_len,void * ciphertext)312 static int rsa_encrypt ( void *ctx, const void *plaintext,
313 			 size_t plaintext_len, void *ciphertext ) {
314 	struct rsa_context *context = ctx;
315 	void *temp;
316 	uint8_t *encoded;
317 	size_t max_len = ( context->max_len - 11 );
318 	size_t random_nz_len = ( max_len - plaintext_len + 8 );
319 	int rc;
320 
321 	/* Sanity check */
322 	if ( plaintext_len > max_len ) {
323 		DBGC ( context, "RSA %p plaintext too long (%zd bytes, max "
324 		       "%zd)\n", context, plaintext_len, max_len );
325 		return -ERANGE;
326 	}
327 	DBGC ( context, "RSA %p encrypting:\n", context );
328 	DBGC_HDA ( context, 0, plaintext, plaintext_len );
329 
330 	/* Construct encoded message (using the big integer output
331 	 * buffer as temporary storage)
332 	 */
333 	temp = context->output0;
334 	encoded = temp;
335 	encoded[0] = 0x00;
336 	encoded[1] = 0x02;
337 	if ( ( rc = get_random_nz ( &encoded[2], random_nz_len ) ) != 0 ) {
338 		DBGC ( context, "RSA %p could not generate random data: %s\n",
339 		       context, strerror ( rc ) );
340 		return rc;
341 	}
342 	encoded[ 2 + random_nz_len ] = 0x00;
343 	memcpy ( &encoded[ context->max_len - plaintext_len ],
344 		 plaintext, plaintext_len );
345 
346 	/* Encipher the encoded message */
347 	rsa_cipher ( context, encoded, ciphertext );
348 	DBGC ( context, "RSA %p encrypted:\n", context );
349 	DBGC_HDA ( context, 0, ciphertext, context->max_len );
350 
351 	return context->max_len;
352 }
353 
354 /**
355  * Decrypt using RSA
356  *
357  * @v ctx		RSA context
358  * @v ciphertext	Ciphertext
359  * @v ciphertext_len	Ciphertext length
360  * @v plaintext		Plaintext
361  * @ret plaintext_len	Plaintext length, or negative error
362  */
rsa_decrypt(void * ctx,const void * ciphertext,size_t ciphertext_len,void * plaintext)363 static int rsa_decrypt ( void *ctx, const void *ciphertext,
364 			 size_t ciphertext_len, void *plaintext ) {
365 	struct rsa_context *context = ctx;
366 	void *temp;
367 	uint8_t *encoded;
368 	uint8_t *end;
369 	uint8_t *zero;
370 	uint8_t *start;
371 	size_t plaintext_len;
372 
373 	/* Sanity check */
374 	if ( ciphertext_len != context->max_len ) {
375 		DBGC ( context, "RSA %p ciphertext incorrect length (%zd "
376 		       "bytes, should be %zd)\n",
377 		       context, ciphertext_len, context->max_len );
378 		return -ERANGE;
379 	}
380 	DBGC ( context, "RSA %p decrypting:\n", context );
381 	DBGC_HDA ( context, 0, ciphertext, ciphertext_len );
382 
383 	/* Decipher the message (using the big integer input buffer as
384 	 * temporary storage)
385 	 */
386 	temp = context->input0;
387 	encoded = temp;
388 	rsa_cipher ( context, ciphertext, encoded );
389 
390 	/* Parse the message */
391 	end = ( encoded + context->max_len );
392 	if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) )
393 		goto invalid;
394 	zero = memchr ( &encoded[2], 0, ( end - &encoded[2] ) );
395 	if ( ! zero )
396 		goto invalid;
397 	start = ( zero + 1 );
398 	plaintext_len = ( end - start );
399 
400 	/* Copy out message */
401 	memcpy ( plaintext, start, plaintext_len );
402 	DBGC ( context, "RSA %p decrypted:\n", context );
403 	DBGC_HDA ( context, 0, plaintext, plaintext_len );
404 
405 	return plaintext_len;
406 
407  invalid:
408 	DBGC ( context, "RSA %p invalid decrypted message:\n", context );
409 	DBGC_HDA ( context, 0, encoded, context->max_len );
410 	return -EINVAL;
411 }
412 
413 /**
414  * Encode RSA digest
415  *
416  * @v context		RSA context
417  * @v digest		Digest algorithm
418  * @v value		Digest value
419  * @v encoded		Encoded digest
420  * @ret rc		Return status code
421  */
rsa_encode_digest(struct rsa_context * context,struct digest_algorithm * digest,const void * value,void * encoded)422 static int rsa_encode_digest ( struct rsa_context *context,
423 			       struct digest_algorithm *digest,
424 			       const void *value, void *encoded ) {
425 	struct rsa_digestinfo_prefix *prefix;
426 	size_t digest_len = digest->digestsize;
427 	uint8_t *temp = encoded;
428 	size_t digestinfo_len;
429 	size_t max_len;
430 	size_t pad_len;
431 
432 	/* Identify prefix */
433 	prefix = rsa_find_prefix ( digest );
434 	if ( ! prefix ) {
435 		DBGC ( context, "RSA %p has no prefix for %s\n",
436 		       context, digest->name );
437 		return -ENOTSUP;
438 	}
439 	digestinfo_len = ( prefix->len + digest_len );
440 
441 	/* Sanity check */
442 	max_len = ( context->max_len - 11 );
443 	if ( digestinfo_len > max_len ) {
444 		DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, max"
445 		       "%zd)\n",
446 		       context, digest->name, digestinfo_len, max_len );
447 		return -ERANGE;
448 	}
449 	DBGC ( context, "RSA %p encoding %s digest:\n",
450 	       context, digest->name );
451 	DBGC_HDA ( context, 0, value, digest_len );
452 
453 	/* Construct encoded message */
454 	*(temp++) = 0x00;
455 	*(temp++) = 0x01;
456 	pad_len = ( max_len - digestinfo_len + 8 );
457 	memset ( temp, 0xff, pad_len );
458 	temp += pad_len;
459 	*(temp++) = 0x00;
460 	memcpy ( temp, prefix->data, prefix->len );
461 	temp += prefix->len;
462 	memcpy ( temp, value, digest_len );
463 	temp += digest_len;
464 	assert ( temp == ( encoded + context->max_len ) );
465 	DBGC ( context, "RSA %p encoded %s digest:\n", context, digest->name );
466 	DBGC_HDA ( context, 0, encoded, context->max_len );
467 
468 	return 0;
469 }
470 
471 /**
472  * Sign digest value using RSA
473  *
474  * @v ctx		RSA context
475  * @v digest		Digest algorithm
476  * @v value		Digest value
477  * @v signature		Signature
478  * @ret signature_len	Signature length, or negative error
479  */
rsa_sign(void * ctx,struct digest_algorithm * digest,const void * value,void * signature)480 static int rsa_sign ( void *ctx, struct digest_algorithm *digest,
481 		      const void *value, void *signature ) {
482 	struct rsa_context *context = ctx;
483 	void *temp;
484 	int rc;
485 
486 	DBGC ( context, "RSA %p signing %s digest:\n", context, digest->name );
487 	DBGC_HDA ( context, 0, value, digest->digestsize );
488 
489 	/* Encode digest (using the big integer output buffer as
490 	 * temporary storage)
491 	 */
492 	temp = context->output0;
493 	if ( ( rc = rsa_encode_digest ( context, digest, value, temp ) ) != 0 )
494 		return rc;
495 
496 	/* Encipher the encoded digest */
497 	rsa_cipher ( context, temp, signature );
498 	DBGC ( context, "RSA %p signed %s digest:\n", context, digest->name );
499 	DBGC_HDA ( context, 0, signature, context->max_len );
500 
501 	return context->max_len;
502 }
503 
504 /**
505  * Verify signed digest value using RSA
506  *
507  * @v ctx		RSA context
508  * @v digest		Digest algorithm
509  * @v value		Digest value
510  * @v signature		Signature
511  * @v signature_len	Signature length
512  * @ret rc		Return status code
513  */
rsa_verify(void * ctx,struct digest_algorithm * digest,const void * value,const void * signature,size_t signature_len)514 static int rsa_verify ( void *ctx, struct digest_algorithm *digest,
515 			const void *value, const void *signature,
516 			size_t signature_len ) {
517 	struct rsa_context *context = ctx;
518 	void *temp;
519 	void *expected;
520 	void *actual;
521 	int rc;
522 
523 	/* Sanity check */
524 	if ( signature_len != context->max_len ) {
525 		DBGC ( context, "RSA %p signature incorrect length (%zd "
526 		       "bytes, should be %zd)\n",
527 		       context, signature_len, context->max_len );
528 		return -ERANGE;
529 	}
530 	DBGC ( context, "RSA %p verifying %s digest:\n",
531 	       context, digest->name );
532 	DBGC_HDA ( context, 0, value, digest->digestsize );
533 	DBGC_HDA ( context, 0, signature, signature_len );
534 
535 	/* Decipher the signature (using the big integer input buffer
536 	 * as temporary storage)
537 	 */
538 	temp = context->input0;
539 	expected = temp;
540 	rsa_cipher ( context, signature, expected );
541 	DBGC ( context, "RSA %p deciphered signature:\n", context );
542 	DBGC_HDA ( context, 0, expected, context->max_len );
543 
544 	/* Encode digest (using the big integer output buffer as
545 	 * temporary storage)
546 	 */
547 	temp = context->output0;
548 	actual = temp;
549 	if ( ( rc = rsa_encode_digest ( context, digest, value, actual ) ) !=0 )
550 		return rc;
551 
552 	/* Verify the signature */
553 	if ( memcmp ( actual, expected, context->max_len ) != 0 ) {
554 		DBGC ( context, "RSA %p signature verification failed\n",
555 		       context );
556 		return -EACCES_VERIFY;
557 	}
558 
559 	DBGC ( context, "RSA %p signature verified successfully\n", context );
560 	return 0;
561 }
562 
/**
 * Finalise RSA cipher
 *
 * Releases the dynamic storage allocated by rsa_init().
 *
 * @v ctx		RSA context
 */
static void rsa_final ( void *ctx ) {
	rsa_free ( ( struct rsa_context * ) ctx );
}
573 
574 /**
575  * Check for matching RSA public/private key pair
576  *
577  * @v private_key	Private key
578  * @v private_key_len	Private key length
579  * @v public_key	Public key
580  * @v public_key_len	Public key length
581  * @ret rc		Return status code
582  */
rsa_match(const void * private_key,size_t private_key_len,const void * public_key,size_t public_key_len)583 static int rsa_match ( const void *private_key, size_t private_key_len,
584 		       const void *public_key, size_t public_key_len ) {
585 	struct asn1_cursor private_modulus;
586 	struct asn1_cursor private_exponent;
587 	struct asn1_cursor private_cursor;
588 	struct asn1_cursor public_modulus;
589 	struct asn1_cursor public_exponent;
590 	struct asn1_cursor public_cursor;
591 	int rc;
592 
593 	/* Initialise cursors */
594 	private_cursor.data = private_key;
595 	private_cursor.len = private_key_len;
596 	public_cursor.data = public_key;
597 	public_cursor.len = public_key_len;
598 
599 	/* Parse moduli and exponents */
600 	if ( ( rc = rsa_parse_mod_exp ( &private_modulus, &private_exponent,
601 					&private_cursor ) ) != 0 )
602 		return rc;
603 	if ( ( rc = rsa_parse_mod_exp ( &public_modulus, &public_exponent,
604 					&public_cursor ) ) != 0 )
605 		return rc;
606 
607 	/* Compare moduli */
608 	if ( asn1_compare ( &private_modulus, &public_modulus ) != 0 )
609 		return -ENOTTY;
610 
611 	return 0;
612 }
613 
614 /** RSA public-key algorithm */
615 struct pubkey_algorithm rsa_algorithm = {
616 	.name		= "rsa",
617 	.ctxsize	= RSA_CTX_SIZE,
618 	.init		= rsa_init,
619 	.max_len	= rsa_max_len,
620 	.encrypt	= rsa_encrypt,
621 	.decrypt	= rsa_decrypt,
622 	.sign		= rsa_sign,
623 	.verify		= rsa_verify,
624 	.final		= rsa_final,
625 	.match		= rsa_match,
626 };
627 
628 /* Drag in objects via rsa_algorithm */
629 REQUIRING_SYMBOL ( rsa_algorithm );
630 
631 /* Drag in crypto configuration */
632 REQUIRE_OBJECT ( config_crypto );
633