xref: /openbsd/lib/libcrypto/sm2/sm2_crypt.c (revision 3bef86f7)
1 /*	$OpenBSD: sm2_crypt.c,v 1.2 2022/11/26 16:08:54 tb Exp $ */
2 /*
3  * Copyright (c) 2017, 2019 Ribose Inc
4  *
5  * Permission to use, copy, modify, and/or distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #ifndef OPENSSL_NO_SM2
19 
20 #include <string.h>
21 
22 #include <openssl/asn1.h>
23 #include <openssl/asn1t.h>
24 #include <openssl/bn.h>
25 #include <openssl/err.h>
26 #include <openssl/evp.h>
27 #include <openssl/sm2.h>
28 
29 #include "sm2_local.h"
30 
31 typedef struct SM2_Ciphertext_st SM2_Ciphertext;
32 
33 SM2_Ciphertext *SM2_Ciphertext_new(void);
34 void SM2_Ciphertext_free(SM2_Ciphertext *a);
35 SM2_Ciphertext *d2i_SM2_Ciphertext(SM2_Ciphertext **a, const unsigned char **in,
36     long len);
37 int i2d_SM2_Ciphertext(SM2_Ciphertext *a, unsigned char **out);
38 
39 struct SM2_Ciphertext_st {
40 	BIGNUM *C1x;
41 	BIGNUM *C1y;
42 	ASN1_OCTET_STRING *C3;
43 	ASN1_OCTET_STRING *C2;
44 };
45 
46 static const ASN1_TEMPLATE SM2_Ciphertext_seq_tt[] = {
47 	{
48 		.flags = 0,
49 		.tag = 0,
50 		.offset = offsetof(SM2_Ciphertext, C1x),
51 		.field_name = "C1x",
52 		.item = &BIGNUM_it,
53 	},
54 	{
55 		.flags = 0,
56 		.tag = 0,
57 		.offset = offsetof(SM2_Ciphertext, C1y),
58 		.field_name = "C1y",
59 		.item = &BIGNUM_it,
60 	},
61 	{
62 		.flags = 0,
63 		.tag = 0,
64 		.offset = offsetof(SM2_Ciphertext, C3),
65 		.field_name = "C3",
66 		.item = &ASN1_OCTET_STRING_it,
67 	},
68 	{
69 		.flags = 0,
70 		.tag = 0,
71 		.offset = offsetof(SM2_Ciphertext, C2),
72 		.field_name = "C2",
73 		.item = &ASN1_OCTET_STRING_it,
74 	},
75 };
76 
77 const ASN1_ITEM SM2_Ciphertext_it = {
78 	.itype = ASN1_ITYPE_SEQUENCE,
79 	.utype = V_ASN1_SEQUENCE,
80 	.templates = SM2_Ciphertext_seq_tt,
81 	.tcount = sizeof(SM2_Ciphertext_seq_tt) / sizeof(ASN1_TEMPLATE),
82 	.funcs = NULL,
83 	.size = sizeof(SM2_Ciphertext),
84 	.sname = "SM2_Ciphertext",
85 };
86 
87 SM2_Ciphertext *
88 d2i_SM2_Ciphertext(SM2_Ciphertext **a, const unsigned char **in, long len)
89 {
90 	return (SM2_Ciphertext *) ASN1_item_d2i((ASN1_VALUE **)a, in, len,
91 	    &SM2_Ciphertext_it);
92 }
93 
94 int
95 i2d_SM2_Ciphertext(SM2_Ciphertext *a, unsigned char **out)
96 {
97 	return ASN1_item_i2d((ASN1_VALUE *)a, out, &SM2_Ciphertext_it);
98 }
99 
100 SM2_Ciphertext *
101 SM2_Ciphertext_new(void)
102 {
103 	return (SM2_Ciphertext *)ASN1_item_new(&SM2_Ciphertext_it);
104 }
105 
106 void
107 SM2_Ciphertext_free(SM2_Ciphertext *a)
108 {
109 	ASN1_item_free((ASN1_VALUE *)a, &SM2_Ciphertext_it);
110 }
111 
112 static size_t
113 ec_field_size(const EC_GROUP *group)
114 {
115 	/* Is there some simpler way to do this? */
116 	BIGNUM *p;
117 	size_t field_size = 0;
118 
119 	if ((p = BN_new()) == NULL)
120 		goto err;
121 	if (!EC_GROUP_get_curve(group, p, NULL, NULL, NULL))
122 		goto err;
123 	field_size = BN_num_bytes(p);
124  err:
125 	BN_free(p);
126 	return field_size;
127 }
128 
129 int
130 SM2_plaintext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len,
131     size_t *pl_size)
132 {
133 	size_t field_size, overhead;
134 	int md_size;
135 
136 	if ((field_size = ec_field_size(EC_KEY_get0_group(key))) == 0) {
137 		SM2error(SM2_R_INVALID_FIELD);
138 		return 0;
139 	}
140 
141 	if ((md_size = EVP_MD_size(digest)) < 0) {
142 		SM2error(SM2_R_INVALID_DIGEST);
143 		return 0;
144 	}
145 
146 	overhead = 10 + 2 * field_size + md_size;
147 	if (msg_len <= overhead) {
148 		SM2error(SM2_R_INVALID_ARGUMENT);
149 		return 0;
150 	}
151 
152 	*pl_size = msg_len - overhead;
153 	return 1;
154 }
155 
156 int
157 SM2_ciphertext_size(const EC_KEY *key, const EVP_MD *digest, size_t msg_len,
158     size_t *c_size)
159 {
160 	size_t asn_size, field_size;
161 	int md_size;
162 
163 	if ((field_size = ec_field_size(EC_KEY_get0_group(key))) == 0) {
164 		SM2error(SM2_R_INVALID_FIELD);
165 		return 0;
166 	}
167 
168 	if ((md_size = EVP_MD_size(digest)) < 0) {
169 		SM2error(SM2_R_INVALID_DIGEST);
170 		return 0;
171 	}
172 
173 	asn_size = 2 * ASN1_object_size(0, field_size + 1, V_ASN1_INTEGER) +
174 	    ASN1_object_size(0, md_size, V_ASN1_OCTET_STRING) +
175 	    ASN1_object_size(0, msg_len, V_ASN1_OCTET_STRING);
176 
177 	*c_size = ASN1_object_size(1, asn_size, V_ASN1_SEQUENCE);
178 	return 1;
179 }
180 
181 int
182 sm2_kdf(uint8_t *key, size_t key_len, uint8_t *secret, size_t secret_len,
183     const EVP_MD *digest)
184 {
185 	EVP_MD_CTX *hash;
186 	uint8_t *hash_buf = NULL;
187 	uint32_t ctr = 1;
188 	uint8_t ctr_buf[4] = {0};
189 	size_t hadd, hlen;
190 	int rc = 0;
191 
192 	if ((hash = EVP_MD_CTX_new()) == NULL) {
193 		SM2error(ERR_R_MALLOC_FAILURE);
194 		goto err;
195 	}
196 
197 	if ((hlen = EVP_MD_size(digest)) < 0) {
198 		SM2error(SM2_R_INVALID_DIGEST);
199 		goto err;
200 	}
201 	if ((hash_buf = malloc(hlen)) == NULL) {
202 		SM2error(ERR_R_MALLOC_FAILURE);
203 		goto err;
204 	}
205 
206 	EVP_MD_CTX_init(hash);
207 	while ((key_len > 0) && (ctr != 0)) {
208 		if (!EVP_DigestInit_ex(hash, digest, NULL)) {
209 			SM2error(ERR_R_EVP_LIB);
210 			goto err;
211 		}
212 		if (!EVP_DigestUpdate(hash, secret, secret_len)) {
213 			SM2error(ERR_R_EVP_LIB);
214 			goto err;
215 		}
216 
217 		/* big-endian counter representation */
218 		ctr_buf[0] = (ctr >> 24) & 0xff;
219 		ctr_buf[1] = (ctr >> 16) & 0xff;
220 		ctr_buf[2] = (ctr >> 8) & 0xff;
221 		ctr_buf[3] = ctr & 0xff;
222 		ctr++;
223 
224 		if (!EVP_DigestUpdate(hash, ctr_buf, 4)) {
225 			SM2error(ERR_R_EVP_LIB);
226 			goto err;
227 		}
228 		if (!EVP_DigestFinal(hash, hash_buf, NULL)) {
229 			SM2error(ERR_R_EVP_LIB);
230 			goto err;
231 		}
232 
233 		hadd = key_len > hlen ? hlen : key_len;
234 		memcpy(key, hash_buf, hadd);
235 		memset(hash_buf, 0, hlen);
236 		key_len -= hadd;
237 		key += hadd;
238 	}
239 
240 	rc = 1;
241  err:
242 	free(hash_buf);
243 	EVP_MD_CTX_free(hash);
244 	return rc;
245 }
246 
247 int
248 SM2_encrypt(const EC_KEY *key, const EVP_MD *digest, const uint8_t *msg,
249     size_t msg_len, uint8_t *ciphertext_buf, size_t *ciphertext_len)
250 {
251 	SM2_Ciphertext ctext_struct;
252 	EVP_MD_CTX *hash = NULL;
253 	BN_CTX *ctx = NULL;
254 	BIGNUM *order = NULL;
255 	BIGNUM *k, *x1, *y1, *x2, *y2;
256 	const EC_GROUP *group;
257 	const EC_POINT *P;
258 	EC_POINT *kG = NULL, *kP = NULL;
259 	uint8_t *msg_mask = NULL, *x2y2 = NULL, *C3 = NULL;
260 	size_t C3_size, field_size, i, x2size, y2size;
261 	int rc = 0;
262 	int clen;
263 
264 	ctext_struct.C2 = NULL;
265 	ctext_struct.C3 = NULL;
266 
267 	if ((hash = EVP_MD_CTX_new()) == NULL) {
268 		SM2error(ERR_R_MALLOC_FAILURE);
269 		goto err;
270 	}
271 
272 	if ((group = EC_KEY_get0_group(key)) == NULL) {
273 		SM2error(SM2_R_INVALID_KEY);
274 		goto err;
275 	}
276 
277 	if ((order = BN_new()) == NULL) {
278 		SM2error(ERR_R_MALLOC_FAILURE);
279 		goto err;
280 	}
281 
282 	if (!EC_GROUP_get_order(group, order, NULL)) {
283 		SM2error(SM2_R_INVALID_GROUP_ORDER);
284 		goto err;
285 	}
286 
287 	if ((P = EC_KEY_get0_public_key(key)) == NULL) {
288 		SM2error(SM2_R_INVALID_KEY);
289 		goto err;
290 	}
291 
292 	if ((field_size = ec_field_size(group)) == 0) {
293 		SM2error(SM2_R_INVALID_FIELD);
294 		goto err;
295 	}
296 
297 	if ((C3_size = EVP_MD_size(digest)) < 0) {
298 		SM2error(SM2_R_INVALID_DIGEST);
299 		goto err;
300 	}
301 
302 	if ((kG = EC_POINT_new(group)) == NULL) {
303 		SM2error(ERR_R_MALLOC_FAILURE);
304 		goto err;
305 	}
306 	if ((kP = EC_POINT_new(group)) == NULL) {
307 		SM2error(ERR_R_MALLOC_FAILURE);
308 		goto err;
309 	}
310 
311 	if ((ctx = BN_CTX_new()) == NULL) {
312 		SM2error(ERR_R_MALLOC_FAILURE);
313 		goto err;
314 	}
315 
316 	BN_CTX_start(ctx);
317 	if ((k = BN_CTX_get(ctx)) == NULL) {
318 		SM2error(ERR_R_BN_LIB);
319 		goto err;
320 	}
321 	if ((x1 = BN_CTX_get(ctx)) == NULL) {
322 		SM2error(ERR_R_BN_LIB);
323 		goto err;
324 	}
325 	if ((x2 = BN_CTX_get(ctx)) == NULL) {
326 		SM2error(ERR_R_BN_LIB);
327 		goto err;
328 	}
329 	if ((y1 = BN_CTX_get(ctx)) == NULL) {
330 		SM2error(ERR_R_BN_LIB);
331 		goto err;
332 	}
333 	if ((y2 = BN_CTX_get(ctx)) == NULL) {
334 		SM2error(ERR_R_BN_LIB);
335 		goto err;
336 	}
337 
338 	if ((x2y2 = calloc(2, field_size)) == NULL) {
339 		SM2error(ERR_R_MALLOC_FAILURE);
340 		goto err;
341 	}
342 
343 	if ((C3 = calloc(1, C3_size)) == NULL) {
344 		SM2error(ERR_R_MALLOC_FAILURE);
345 		goto err;
346 	}
347 
348 	memset(ciphertext_buf, 0, *ciphertext_len);
349 
350 	if (!BN_rand_range(k, order)) {
351 		SM2error(SM2_R_RANDOM_NUMBER_GENERATION_FAILED);
352 		goto err;
353 	}
354 
355 	if (!EC_POINT_mul(group, kG, k, NULL, NULL, ctx)) {
356 		SM2error(ERR_R_EC_LIB);
357 		goto err;
358 	}
359 
360 	if (!EC_POINT_get_affine_coordinates(group, kG, x1, y1, ctx)) {
361 		SM2error(ERR_R_EC_LIB);
362 		goto err;
363 	}
364 
365 	if (!EC_POINT_mul(group, kP, NULL, P, k, ctx)) {
366 		SM2error(ERR_R_EC_LIB);
367 		goto err;
368 	}
369 
370 	if (!EC_POINT_get_affine_coordinates(group, kP, x2, y2, ctx)) {
371 		SM2error(ERR_R_EC_LIB);
372 		goto err;
373 	}
374 
375 	if ((x2size = BN_num_bytes(x2)) > field_size ||
376 	    (y2size = BN_num_bytes(y2)) > field_size) {
377 		SM2error(SM2_R_BIGNUM_OUT_OF_RANGE);
378 		goto err;
379 	}
380 
381 	BN_bn2bin(x2, x2y2 + field_size - x2size);
382 	BN_bn2bin(y2, x2y2 + 2 * field_size - y2size);
383 
384 	if ((msg_mask = calloc(1, msg_len)) == NULL) {
385 		SM2error(ERR_R_MALLOC_FAILURE);
386 		goto err;
387 	}
388 
389 	if (!sm2_kdf(msg_mask, msg_len, x2y2, 2 * field_size, digest)) {
390 		SM2error(SM2_R_KDF_FAILURE);
391 		goto err;
392 	}
393 
394 	for (i = 0; i != msg_len; i++)
395 		msg_mask[i] ^= msg[i];
396 
397 	if (!EVP_DigestInit(hash, digest)) {
398 		SM2error(ERR_R_EVP_LIB);
399 		goto err;
400 	}
401 
402 	if (!EVP_DigestUpdate(hash, x2y2, field_size)) {
403 		SM2error(ERR_R_EVP_LIB);
404 		goto err;
405 	}
406 
407 	if (!EVP_DigestUpdate(hash, msg, msg_len)) {
408 		SM2error(ERR_R_EVP_LIB);
409 		goto err;
410 	}
411 
412 	if (!EVP_DigestUpdate(hash, x2y2 + field_size, field_size)) {
413 		SM2error(ERR_R_EVP_LIB);
414 		goto err;
415 	}
416 
417 	if (!EVP_DigestFinal(hash, C3, NULL)) {
418 		SM2error(ERR_R_EVP_LIB);
419 		goto err;
420 	}
421 
422 	ctext_struct.C1x = x1;
423 	ctext_struct.C1y = y1;
424 	if ((ctext_struct.C3 = ASN1_OCTET_STRING_new()) == NULL) {
425 		SM2error(ERR_R_MALLOC_FAILURE);
426 		goto err;
427 	}
428 	if ((ctext_struct.C2 = ASN1_OCTET_STRING_new()) == NULL) {
429 		SM2error(ERR_R_MALLOC_FAILURE);
430 		goto err;
431 	}
432 	if (!ASN1_OCTET_STRING_set(ctext_struct.C3, C3, C3_size)) {
433 		SM2error(ERR_R_INTERNAL_ERROR);
434 		goto err;
435 	}
436 	if (!ASN1_OCTET_STRING_set(ctext_struct.C2, msg_mask, msg_len)) {
437 		SM2error(ERR_R_INTERNAL_ERROR);
438 		goto err;
439 	}
440 
441 	if ((clen = i2d_SM2_Ciphertext(&ctext_struct, &ciphertext_buf)) < 0) {
442 		SM2error(ERR_R_INTERNAL_ERROR);
443 		goto err;
444 	}
445 
446 	*ciphertext_len = clen;
447 	rc = 1;
448 
449  err:
450 	ASN1_OCTET_STRING_free(ctext_struct.C2);
451 	ASN1_OCTET_STRING_free(ctext_struct.C3);
452 	free(msg_mask);
453 	free(x2y2);
454 	free(C3);
455 	EVP_MD_CTX_free(hash);
456 	BN_CTX_end(ctx);
457 	BN_CTX_free(ctx);
458 	EC_POINT_free(kG);
459 	EC_POINT_free(kP);
460 	BN_free(order);
461 	return rc;
462 }
463 
464 int
465 SM2_decrypt(const EC_KEY *key, const EVP_MD *digest, const uint8_t *ciphertext,
466     size_t ciphertext_len, uint8_t *ptext_buf, size_t *ptext_len)
467 {
468 	SM2_Ciphertext *sm2_ctext = NULL;
469 	EVP_MD_CTX *hash = NULL;
470 	BN_CTX *ctx = NULL;
471 	BIGNUM *x2, *y2;
472 	const EC_GROUP *group;
473 	EC_POINT *C1 = NULL;
474 	const uint8_t *C2, *C3;
475 	uint8_t *computed_C3 = NULL, *msg_mask = NULL, *x2y2 = NULL;
476 	size_t field_size, x2size, y2size;
477 	int msg_len = 0, rc = 0;
478 	int hash_size, i;
479 
480 	if ((group = EC_KEY_get0_group(key)) == NULL) {
481 		SM2error(SM2_R_INVALID_KEY);
482 		goto err;
483 	}
484 
485 	if ((field_size = ec_field_size(group)) == 0) {
486 		SM2error(SM2_R_INVALID_FIELD);
487 		goto err;
488 	}
489 
490 	if ((hash_size = EVP_MD_size(digest)) < 0) {
491 		SM2error(SM2_R_INVALID_DIGEST);
492 		goto err;
493 	}
494 
495 	memset(ptext_buf, 0xFF, *ptext_len);
496 
497 	if ((sm2_ctext = d2i_SM2_Ciphertext(NULL, &ciphertext,
498 	    ciphertext_len)) == NULL) {
499 		SM2error(SM2_R_ASN1_ERROR);
500 		goto err;
501 	}
502 
503 	if (sm2_ctext->C3->length != hash_size) {
504 		SM2error(SM2_R_INVALID_ENCODING);
505 		goto err;
506 	}
507 
508 	C2 = sm2_ctext->C2->data;
509 	C3 = sm2_ctext->C3->data;
510 	msg_len = sm2_ctext->C2->length;
511 
512 	if ((ctx = BN_CTX_new()) == NULL) {
513 		SM2error(ERR_R_MALLOC_FAILURE);
514 		goto err;
515 	}
516 
517 	BN_CTX_start(ctx);
518 	if ((x2 = BN_CTX_get(ctx)) == NULL) {
519 		SM2error(ERR_R_BN_LIB);
520 		goto err;
521 	}
522 	if ((y2 = BN_CTX_get(ctx)) == NULL) {
523 		SM2error(ERR_R_BN_LIB);
524 		goto err;
525 	}
526 
527 	if ((msg_mask = calloc(1, msg_len)) == NULL) {
528 		SM2error(ERR_R_MALLOC_FAILURE);
529 		goto err;
530 	}
531 	if ((x2y2 = calloc(2, field_size)) == NULL) {
532 		SM2error(ERR_R_MALLOC_FAILURE);
533 		goto err;
534 	}
535 	if ((computed_C3 = calloc(1, hash_size)) == NULL) {
536 		SM2error(ERR_R_MALLOC_FAILURE);
537 		goto err;
538 	}
539 
540 	if ((C1 = EC_POINT_new(group)) == NULL) {
541 		SM2error(ERR_R_MALLOC_FAILURE);
542 		goto err;
543 	}
544 
545 	if (!EC_POINT_set_affine_coordinates(group, C1, sm2_ctext->C1x,
546 	    sm2_ctext->C1y, ctx))
547 	{
548 		SM2error(ERR_R_EC_LIB);
549 		goto err;
550 	}
551 
552 	if (!EC_POINT_mul(group, C1, NULL, C1, EC_KEY_get0_private_key(key),
553 	    ctx)) {
554 		SM2error(ERR_R_EC_LIB);
555 		goto err;
556 	}
557 
558 	if (!EC_POINT_get_affine_coordinates(group, C1, x2, y2, ctx)) {
559 		SM2error(ERR_R_EC_LIB);
560 		goto err;
561 	}
562 
563 	if ((x2size = BN_num_bytes(x2)) > field_size ||
564 	    (y2size = BN_num_bytes(y2)) > field_size) {
565 		SM2error(SM2_R_BIGNUM_OUT_OF_RANGE);
566 		goto err;
567 	}
568 
569 	BN_bn2bin(x2, x2y2 + field_size - x2size);
570 	BN_bn2bin(y2, x2y2 + 2 * field_size - y2size);
571 
572 	if (!sm2_kdf(msg_mask, msg_len, x2y2, 2 * field_size, digest)) {
573 		SM2error(SM2_R_KDF_FAILURE);
574 		goto err;
575 	}
576 
577 	for (i = 0; i != msg_len; ++i)
578 		ptext_buf[i] = C2[i] ^ msg_mask[i];
579 
580 	if ((hash = EVP_MD_CTX_new()) == NULL) {
581 		SM2error(ERR_R_EVP_LIB);
582 		goto err;
583 	}
584 
585 	if (!EVP_DigestInit(hash, digest)) {
586 		SM2error(ERR_R_EVP_LIB);
587 		goto err;
588 	}
589 
590 	if (!EVP_DigestUpdate(hash, x2y2, field_size)) {
591 		SM2error(ERR_R_EVP_LIB);
592 		goto err;
593 	}
594 
595 	if (!EVP_DigestUpdate(hash, ptext_buf, msg_len)) {
596 		SM2error(ERR_R_EVP_LIB);
597 		goto err;
598 	}
599 
600 	if (!EVP_DigestUpdate(hash, x2y2 + field_size, field_size)) {
601 		SM2error(ERR_R_EVP_LIB);
602 		goto err;
603 	}
604 
605 	if (!EVP_DigestFinal(hash, computed_C3, NULL)) {
606 		SM2error(ERR_R_EVP_LIB);
607 		goto err;
608 	}
609 
610 	if (memcmp(computed_C3, C3, hash_size) != 0)
611 		goto err;
612 
613 	rc = 1;
614 	*ptext_len = msg_len;
615 
616  err:
617 	if (rc == 0)
618 		memset(ptext_buf, 0, *ptext_len);
619 
620 	free(msg_mask);
621 	free(x2y2);
622 	free(computed_C3);
623 	EC_POINT_free(C1);
624 	BN_CTX_end(ctx);
625 	BN_CTX_free(ctx);
626 	SM2_Ciphertext_free(sm2_ctext);
627 	EVP_MD_CTX_free(hash);
628 
629 	return rc;
630 }
631 
632 #endif /* OPENSSL_NO_SM2 */
633