1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <linux/errno.h>
4 
5 int ceph_armor(char *dst, const char *src, const char *end);
6 int ceph_unarmor(char *dst, const char *src, const char *end);
7 
8 /*
9  * base64 encode/decode.
10  */
11 
12 static const char *pem_key =
13 	"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
14 
encode_bits(int c)15 static int encode_bits(int c)
16 {
17 	return pem_key[c];
18 }
19 
decode_bits(char c)20 static int decode_bits(char c)
21 {
22 	if (c >= 'A' && c <= 'Z')
23 		return c - 'A';
24 	if (c >= 'a' && c <= 'z')
25 		return c - 'a' + 26;
26 	if (c >= '0' && c <= '9')
27 		return c - '0' + 52;
28 	if (c == '+')
29 		return 62;
30 	if (c == '/')
31 		return 63;
32 	if (c == '=')
33 		return 0; /* just non-negative, please */
34 	return -EINVAL;
35 }
36 
ceph_armor(char * dst,const char * src,const char * end)37 int ceph_armor(char *dst, const char *src, const char *end)
38 {
39 	int olen = 0;
40 	int line = 0;
41 
42 	while (src < end) {
43 		unsigned char a, b, c;
44 
45 		a = *src++;
46 		*dst++ = encode_bits(a >> 2);
47 		if (src < end) {
48 			b = *src++;
49 			*dst++ = encode_bits(((a & 3) << 4) | (b >> 4));
50 			if (src < end) {
51 				c = *src++;
52 				*dst++ = encode_bits(((b & 15) << 2) |
53 						     (c >> 6));
54 				*dst++ = encode_bits(c & 63);
55 			} else {
56 				*dst++ = encode_bits((b & 15) << 2);
57 				*dst++ = '=';
58 			}
59 		} else {
60 			*dst++ = encode_bits(((a & 3) << 4));
61 			*dst++ = '=';
62 			*dst++ = '=';
63 		}
64 		olen += 4;
65 		line += 4;
66 		if (line == 64) {
67 			line = 0;
68 			*(dst++) = '\n';
69 			olen++;
70 		}
71 	}
72 	return olen;
73 }
74 
ceph_unarmor(char * dst,const char * src,const char * end)75 int ceph_unarmor(char *dst, const char *src, const char *end)
76 {
77 	int olen = 0;
78 
79 	while (src < end) {
80 		int a, b, c, d;
81 
82 		if (src[0] == '\n') {
83 			src++;
84 			continue;
85 		}
86 		if (src + 4 > end)
87 			return -EINVAL;
88 		a = decode_bits(src[0]);
89 		b = decode_bits(src[1]);
90 		c = decode_bits(src[2]);
91 		d = decode_bits(src[3]);
92 		if (a < 0 || b < 0 || c < 0 || d < 0)
93 			return -EINVAL;
94 
95 		*dst++ = (a << 2) | (b >> 4);
96 		if (src[2] == '=')
97 			return olen + 1;
98 		*dst++ = ((b & 15) << 4) | (c >> 2);
99 		if (src[3] == '=')
100 			return olen + 2;
101 		*dst++ = ((c & 3) << 6) | d;
102 		olen += 3;
103 		src += 4;
104 	}
105 	return olen;
106 }
107