1 /**
2  * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3  * SPDX-License-Identifier: Apache-2.0.
4  */
5 
#include <aws/common/encoding.h>

#include <ctype.h>
#include <stdint.h>
#include <stdlib.h>
10 
#ifdef USE_SIMD_ENCODING
/* SIMD implementations provided by another translation unit; selected at runtime via aws_common_private_has_avx2(). */
size_t aws_common_private_base64_decode_sse41(const unsigned char *in, unsigned char *out, size_t len);
void aws_common_private_base64_encode_sse41(const unsigned char *in, unsigned char *out, size_t len);
bool aws_common_private_has_avx2(void);
#else
/*
 * SIMD support was not compiled in. The feature probe is hard-wired to report "no AVX2", so the encoder and
 * decoder always take their portable C paths. The two codec stubs exist only so those call sites still compile
 * and link; reaching either one is a logic error, which the assertion flags.
 */
static inline bool aws_common_private_has_avx2(void) {
    return false;
}

static inline size_t aws_common_private_base64_decode_sse41(const unsigned char *in, unsigned char *out, size_t len) {
    (void)in;
    (void)out;
    (void)len;
    AWS_ASSERT(false);
    return (size_t)-1; /* never reached */
}

static inline void aws_common_private_base64_encode_sse41(const unsigned char *in, unsigned char *out, size_t len) {
    (void)in;
    (void)out;
    (void)len;
    AWS_ASSERT(false);
}
#endif
38 
/* Lowercase hex digits, indexed by nibble value (0-15). A const array, so neither the table nor its address can be reassigned. */
static const uint8_t HEX_CHARS[] = "0123456789abcdef";

/* Value BASE64_DECODING_TABLE stores for the '=' padding character. (Misspelling kept because other code refers to this name.) */
static const uint8_t BASE64_SENTIANAL_VALUE = 0xff;
/* Standard base64 alphabet (RFC 4648), indexed by 6-bit value. */
static const uint8_t BASE64_ENCODING_TABLE[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
43 
/* Maps every byte to its 6-bit base64 value. 0xDD marks a byte outside the base64 alphabet (including NUL, so
 * embedded NULs are rejected), and 0xFF (BASE64_SENTIANAL_VALUE) marks the '=' padding character. If you have to
 * count bytes for any reason, there are 16 per row. Reformatting is turned off to keep it at 16 per line. */
/* clang-format off */
static const uint8_t BASE64_DECODING_TABLE[256] = {
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 62,   0xDD, 0xDD, 0xDD, 63,
    52,   53,   54,   55,   56,   57,   58,   59,   60,   61,   0xDD, 0xDD, 0xDD, 255,  0xDD, 0xDD,
    0xDD, 0,    1,    2,    3,    4,    5,    6,    7,    8,    9,    10,   11,   12,   13,   14,
    15,   16,   17,   18,   19,   20,   21,   22,   23,   24,   25,   0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 26,   27,   28,   29,   30,   31,   32,   33,   34,   35,   36,   37,   38,   39,   40,
    41,   42,   43,   44,   45,   46,   47,   48,   49,   50,   51,   0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD,
    0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD, 0xDD};
/* clang-format on */
65 
aws_hex_compute_encoded_len(size_t to_encode_len,size_t * encoded_length)66 int aws_hex_compute_encoded_len(size_t to_encode_len, size_t *encoded_length) {
67     AWS_ASSERT(encoded_length);
68 
69     size_t temp = (to_encode_len << 1) + 1;
70 
71     if (AWS_UNLIKELY(temp < to_encode_len)) {
72         return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
73     }
74 
75     *encoded_length = temp;
76 
77     return AWS_OP_SUCCESS;
78 }
79 
/*
 * Writes the lowercase hex encoding of to_encode into output, starting at offset 0. Existing contents are
 * overwritten, not appended to. A NUL terminator follows the digits, and output->len is set to the size
 * reported by aws_hex_compute_encoded_len() (2 * to_encode->len + 1), which counts that terminator.
 * Fails with AWS_ERROR_SHORT_BUFFER when output->capacity is smaller than that size.
 */
int aws_hex_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct aws_byte_buf *AWS_RESTRICT output) {
    AWS_PRECONDITION(aws_byte_cursor_is_valid(to_encode));
    AWS_PRECONDITION(aws_byte_buf_is_valid(output));

    size_t required = 0;
    if (AWS_UNLIKELY(aws_hex_compute_encoded_len(to_encode->len, &required))) {
        return AWS_OP_ERR;
    }
    if (AWS_UNLIKELY(output->capacity < required)) {
        return aws_raise_error(AWS_ERROR_SHORT_BUFFER);
    }

    uint8_t *dst = output->buffer;
    for (size_t idx = 0; idx < to_encode->len; ++idx) {
        const uint8_t byte = to_encode->ptr[idx];
        *dst++ = HEX_CHARS[(byte >> 4) & 0x0f]; /* high nibble first */
        *dst++ = HEX_CHARS[byte & 0x0f];
    }
    *dst = '\0';

    output->len = required;

    return AWS_OP_SUCCESS;
}
106 
/*
 * Appends the lowercase hex encoding of to_encode to the end of output. Space is reserved first through
 * aws_byte_buf_reserve_relative(). Unlike aws_hex_encode(), no NUL terminator is written, and output->len
 * grows by exactly 2 * to_encode->len.
 * Returns AWS_OP_ERR if the length computation overflows or the reservation fails.
 */
int aws_hex_encode_append_dynamic(
    const struct aws_byte_cursor *AWS_RESTRICT to_encode,
    struct aws_byte_buf *AWS_RESTRICT output) {
    /* An empty cursor may validly carry a NULL ptr, so check general validity instead of asserting ptr != NULL. */
    AWS_PRECONDITION(aws_byte_cursor_is_valid(to_encode));
    AWS_PRECONDITION(aws_byte_buf_is_valid(output));

    size_t encoded_len = 0;
    if (AWS_UNLIKELY(aws_add_size_checked(to_encode->len, to_encode->len, &encoded_len))) {
        return AWS_OP_ERR;
    }

    if (AWS_UNLIKELY(aws_byte_buf_reserve_relative(output, encoded_len))) {
        return AWS_OP_ERR;
    }

    size_t written = output->len;
    for (size_t i = 0; i < to_encode->len; ++i) {
        const uint8_t byte = to_encode->ptr[i];
        output->buffer[written++] = HEX_CHARS[(byte >> 4) & 0x0f];
        output->buffer[written++] = HEX_CHARS[byte & 0x0f];
    }

    output->len += encoded_len;

    return AWS_OP_SUCCESS;
}
133 
s_hex_decode_char_to_int(char character,uint8_t * int_val)134 static int s_hex_decode_char_to_int(char character, uint8_t *int_val) {
135     if (character >= 'a' && character <= 'f') {
136         *int_val = (uint8_t)(10 + (character - 'a'));
137         return 0;
138     }
139 
140     if (character >= 'A' && character <= 'F') {
141         *int_val = (uint8_t)(10 + (character - 'A'));
142         return 0;
143     }
144 
145     if (character >= '0' && character <= '9') {
146         *int_val = (uint8_t)(character - '0');
147         return 0;
148     }
149 
150     return AWS_OP_ERR;
151 }
152 
aws_hex_compute_decoded_len(size_t to_decode_len,size_t * decoded_len)153 int aws_hex_compute_decoded_len(size_t to_decode_len, size_t *decoded_len) {
154     AWS_ASSERT(decoded_len);
155 
156     size_t temp = (to_decode_len + 1);
157 
158     if (AWS_UNLIKELY(temp < to_decode_len)) {
159         return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
160     }
161 
162     *decoded_len = temp >> 1;
163     return AWS_OP_SUCCESS;
164 }
165 
/*
 * Decodes hex digits (either case) from to_decode into output, starting at offset 0 (existing contents are
 * overwritten, not appended to). An odd-length input is decoded as if it had a leading '0'.
 * On success output->len is set to the decoded byte count.
 * Raises AWS_ERROR_SHORT_BUFFER when output->capacity is too small, and AWS_ERROR_INVALID_HEX_STR on a
 * non-hex character. In the latter case, bytes decoded before the bad character may already be in output->buffer.
 */
int aws_hex_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, struct aws_byte_buf *AWS_RESTRICT output) {
    AWS_PRECONDITION(aws_byte_cursor_is_valid(to_decode));
    AWS_PRECONDITION(aws_byte_buf_is_valid(output));

    size_t decoded_length = 0;

    /* aws_hex_compute_decoded_len() raises its own error on failure; propagate it instead of raising again. */
    if (AWS_UNLIKELY(aws_hex_compute_decoded_len(to_decode->len, &decoded_length))) {
        return AWS_OP_ERR;
    }

    if (AWS_UNLIKELY(output->capacity < decoded_length)) {
        return aws_raise_error(AWS_ERROR_SHORT_BUFFER);
    }

    size_t written = 0;
    size_t i = 0;
    uint8_t high_value = 0;
    uint8_t low_value = 0;

    /* Odd length: treat the first digit alone as the low nibble of the first byte (an implicit leading '0'). */
    if (AWS_UNLIKELY(to_decode->len & 0x01)) {
        i = 1;
        if (s_hex_decode_char_to_int((char)to_decode->ptr[0], &low_value)) {
            return aws_raise_error(AWS_ERROR_INVALID_HEX_STR);
        }

        output->buffer[written++] = low_value;
    }

    /* The remaining digits come in pairs: high nibble, then low nibble. */
    for (; i < to_decode->len; i += 2) {
        if (AWS_UNLIKELY(
                s_hex_decode_char_to_int((char)to_decode->ptr[i], &high_value) ||
                s_hex_decode_char_to_int((char)to_decode->ptr[i + 1], &low_value))) {
            return aws_raise_error(AWS_ERROR_INVALID_HEX_STR);
        }

        output->buffer[written++] = (uint8_t)((high_value << 4) | low_value);
    }

    output->len = decoded_length;

    return AWS_OP_SUCCESS;
}
211 
aws_base64_compute_encoded_len(size_t to_encode_len,size_t * encoded_len)212 int aws_base64_compute_encoded_len(size_t to_encode_len, size_t *encoded_len) {
213     AWS_ASSERT(encoded_len);
214 
215     size_t tmp = to_encode_len + 2;
216 
217     if (AWS_UNLIKELY(tmp < to_encode_len)) {
218         return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
219     }
220 
221     tmp /= 3;
222     size_t overflow_check = tmp;
223     tmp = 4 * tmp + 1; /* plus one for the NULL terminator */
224 
225     if (AWS_UNLIKELY(tmp < overflow_check)) {
226         return aws_raise_error(AWS_ERROR_OVERFLOW_DETECTED);
227     }
228 
229     *encoded_len = tmp;
230 
231     return AWS_OP_SUCCESS;
232 }
233 
/*
 * Computes how many bytes aws_base64_decode() produces for to_decode.
 * The input length must be 0 or a multiple of 4, otherwise AWS_ERROR_INVALID_BASE64_STR is raised.
 * Each 4-character block yields 3 bytes, minus one byte per trailing '=' (at most two are counted).
 * Only the length and the trailing padding are inspected here; the characters are validated while decoding.
 */
int aws_base64_compute_decoded_len(const struct aws_byte_cursor *AWS_RESTRICT to_decode, size_t *decoded_len) {
    AWS_ASSERT(to_decode);
    AWS_ASSERT(decoded_len);

    const size_t len = to_decode->len;
    const uint8_t *input = to_decode->ptr;

    if (len == 0) {
        *decoded_len = 0;
        return AWS_OP_SUCCESS;
    }

    if (AWS_UNLIKELY(len & 0x03)) {
        return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
    }

    /* len >= 4 here, so input[len - 2] is in bounds. */
    size_t padding = 0;
    if (input[len - 1] == '=') {
        padding = (input[len - 2] == '=') ? 2 : 1;
    }

    /*
     * len is a multiple of 4, so dividing first is exact and (len / 4) * 3 < len can never overflow.
     * Computing len * 3 first can wrap, and a "result < len" test does not catch every wrap.
     */
    *decoded_len = (len / 4) * 3 - padding;
    return AWS_OP_SUCCESS;
}
267 
/*
 * Base64-encodes to_encode (standard alphabet, '=' padding) and appends the result to output at output->len.
 * A NUL terminator is written right after the encoded text but is not counted: output->len grows by the
 * encoded length only.
 * The buffer is never grown. Raises AWS_ERROR_SHORT_BUFFER when output->capacity cannot hold the encoding
 * plus the terminator, and returns AWS_OP_ERR if the size computation overflows.
 */
int aws_base64_encode(const struct aws_byte_cursor *AWS_RESTRICT to_encode, struct aws_byte_buf *AWS_RESTRICT output) {
    /*
     * Check validity rather than non-NULL pointers. An empty cursor may validly carry a NULL ptr, and a
     * zero-capacity buffer a NULL buffer; the capacity check below reports the latter as SHORT_BUFFER.
     */
    AWS_PRECONDITION(aws_byte_cursor_is_valid(to_encode));
    AWS_PRECONDITION(aws_byte_buf_is_valid(output));

    size_t terminated_length = 0;
    if (AWS_UNLIKELY(aws_base64_compute_encoded_len(to_encode->len, &terminated_length))) {
        return AWS_OP_ERR;
    }

    size_t needed_capacity = 0;
    if (AWS_UNLIKELY(aws_add_size_checked(output->len, terminated_length, &needed_capacity))) {
        return AWS_OP_ERR;
    }

    if (AWS_UNLIKELY(output->capacity < needed_capacity)) {
        return aws_raise_error(AWS_ERROR_SHORT_BUFFER);
    }

    /*
     * The output is NUL-terminated for the convenience of standard C functions that expect a string. The
     * encoding itself can be used in other ways, so its length never includes that byte.
     */
    const size_t encoded_length = terminated_length - 1;
    uint8_t *dst = output->buffer + output->len;

    if (aws_common_private_has_avx2()) {
        aws_common_private_base64_encode_sse41(to_encode->ptr, dst, to_encode->len);
    } else {
        const uint8_t *src = to_encode->ptr;
        const size_t src_len = to_encode->len;
        size_t out_index = 0;

        for (size_t i = 0; i < src_len; i += 3) {
            /* Pack up to three bytes big-endian into 24 bits; missing trailing bytes read as zero. */
            uint32_t block = (uint32_t)src[i] << 16;
            if (AWS_LIKELY(i + 1 < src_len)) {
                block |= (uint32_t)src[i + 1] << 8;
            }
            if (AWS_LIKELY(i + 2 < src_len)) {
                block |= (uint32_t)src[i + 2];
            }

            dst[out_index++] = BASE64_ENCODING_TABLE[(block >> 18) & 0x3F];
            dst[out_index++] = BASE64_ENCODING_TABLE[(block >> 12) & 0x3F];
            dst[out_index++] = BASE64_ENCODING_TABLE[(block >> 6) & 0x3F];
            dst[out_index++] = BASE64_ENCODING_TABLE[block & 0x3F];
        }

        /* Overwrite the characters that came from zero-filled input bytes with '=' padding. */
        const size_t remainder_count = src_len % 3;
        if (remainder_count > 0) {
            dst[encoded_length - 1] = '=';
            if (remainder_count == 1) {
                dst[encoded_length - 2] = '=';
            }
        }
    }

    /* it's a string, add the NUL terminator */
    dst[encoded_length] = 0;
    output->len += encoded_length;

    return AWS_OP_SUCCESS;
}
339 
s_base64_get_decoded_value(unsigned char to_decode,uint8_t * value,int8_t allow_sentinal)340 static inline int s_base64_get_decoded_value(unsigned char to_decode, uint8_t *value, int8_t allow_sentinal) {
341 
342     uint8_t decode_value = BASE64_DECODING_TABLE[(size_t)to_decode];
343     if (decode_value != 0xDD && (decode_value != BASE64_SENTIANAL_VALUE || allow_sentinal)) {
344         *value = decode_value;
345         return AWS_OP_SUCCESS;
346     }
347 
348     return AWS_OP_ERR;
349 }
350 
/*
 * Decodes base64 text from to_decode into output, starting at offset 0 (existing contents are overwritten,
 * not appended to). The input length must be a multiple of 4. '=' padding is accepted only as the final one
 * or two characters.
 * On success output->len is set to the decoded byte count.
 * Raises AWS_ERROR_SHORT_BUFFER when output->capacity is too small, and AWS_ERROR_INVALID_BASE64_STR on
 * malformed input. In the latter case, some bytes may already have been written to output->buffer.
 */
int aws_base64_decode(const struct aws_byte_cursor *AWS_RESTRICT to_decode, struct aws_byte_buf *AWS_RESTRICT output) {
    size_t decoded_length = 0;

    if (AWS_UNLIKELY(aws_base64_compute_decoded_len(to_decode, &decoded_length))) {
        return AWS_OP_ERR;
    }

    if (output->capacity < decoded_length) {
        return aws_raise_error(AWS_ERROR_SHORT_BUFFER);
    }

    if (aws_common_private_has_avx2()) {
        size_t result = aws_common_private_base64_decode_sse41(to_decode->ptr, output->buffer, to_decode->len);
        /* (size_t)-1 signals malformed input; compare as size_t to avoid a signed/unsigned comparison. */
        if (result == (size_t)-1) {
            return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
        }

        output->len = result;
        return AWS_OP_SUCCESS;
    }

    /* aws_base64_compute_decoded_len() guarantees len is 0 or a multiple of 4. */
    const size_t block_count = to_decode->len / 4;
    if (block_count == 0) {
        output->len = 0;
        return AWS_OP_SUCCESS;
    }

    const uint8_t *src = to_decode->ptr;
    uint8_t *dst = output->buffer;
    uint8_t value1 = 0, value2 = 0, value3 = 0, value4 = 0;

    /* Every block except the last must be four alphabet characters; '=' is rejected here. */
    for (size_t block = 0; block + 1 < block_count; ++block) {
        const uint8_t *in = src + block * 4;
        if (AWS_UNLIKELY(
                s_base64_get_decoded_value(in[0], &value1, 0) || s_base64_get_decoded_value(in[1], &value2, 0) ||
                s_base64_get_decoded_value(in[2], &value3, 0) || s_base64_get_decoded_value(in[3], &value4, 0))) {
            return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
        }

        uint8_t *out = dst + block * 3;
        out[0] = (uint8_t)((value1 << 2) | ((value2 >> 4) & 0x03));
        out[1] = (uint8_t)(((value2 << 4) & 0xF0) | ((value3 >> 2) & 0x0F));
        out[2] = (uint8_t)(((value3 & 0x03) << 6) | value4);
    }

    /* In the last block, the third and fourth characters may be the '=' padding sentinel. */
    const uint8_t *last_in = src + (block_count - 1) * 4;
    uint8_t *last_out = dst + (block_count - 1) * 3;
    if (s_base64_get_decoded_value(last_in[0], &value1, 0) || s_base64_get_decoded_value(last_in[1], &value2, 0) ||
        s_base64_get_decoded_value(last_in[2], &value3, 1) || s_base64_get_decoded_value(last_in[3], &value4, 1)) {
        return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
    }

    /*
     * Padding must be a suffix. Without this check, "xx=y" would decode a single byte while decoded_length,
     * which only counts trailing '=', reports three, exposing two unwritten output bytes.
     */
    if (AWS_UNLIKELY(value3 == BASE64_SENTIANAL_VALUE && value4 != BASE64_SENTIANAL_VALUE)) {
        return aws_raise_error(AWS_ERROR_INVALID_BASE64_STR);
    }

    last_out[0] = (uint8_t)((value1 << 2) | ((value2 >> 4) & 0x03));

    if (value3 != BASE64_SENTIANAL_VALUE) {
        last_out[1] = (uint8_t)(((value2 << 4) & 0xF0) | ((value3 >> 2) & 0x0F));
        if (value4 != BASE64_SENTIANAL_VALUE) {
            last_out[2] = (uint8_t)(((value3 & 0x03) << 6) | value4);
        }
    }

    output->len = decoded_length;
    return AWS_OP_SUCCESS;
}
414