1 /*-------------------------------------------------------------------------
2  *
3  * encode.c
4  *	  Various data encoding/decoding things.
5  *
6  * Copyright (c) 2001-2021, PostgreSQL Global Development Group
7  *
8  *
9  * IDENTIFICATION
10  *	  src/backend/utils/adt/encode.c
11  *
12  *-------------------------------------------------------------------------
13  */
14 #include "postgres.h"
15 
16 #include <ctype.h>
17 
18 #include "mb/pg_wchar.h"
19 #include "utils/builtins.h"
20 #include "utils/memutils.h"
21 
22 
/*
 * Encoding conversion API.
 * encode_len() and decode_len() compute the amount of space needed, while
 * encode() and decode() perform the actual conversions.  It is okay for
 * the _len functions to return an overestimate, but not an underestimate.
 * (Having said that, large overestimates could cause unnecessary errors,
 * so it's better to get it right.)  The conversion routines write to the
 * buffer at *res and return the true length of their output.
 */
struct pg_encoding
{
	/* upper bound on the number of bytes encode() writes for this input */
	uint64		(*encode_len) (const char *data, size_t dlen);
	/* upper bound on the number of bytes decode() writes for this input */
	uint64		(*decode_len) (const char *data, size_t dlen);
	/* encode data[0..dlen-1] into *res; returns the bytes actually written */
	uint64		(*encode) (const char *data, size_t dlen, char *res);
	/* decode data[0..dlen-1] into *res; returns the bytes actually written */
	uint64		(*decode) (const char *data, size_t dlen, char *res);
};

/* Find an encoding by name (case-insensitive); returns NULL if unknown. */
static const struct pg_encoding *pg_find_encoding(const char *name);
41 
42 /*
43  * SQL functions.
44  */
45 
46 Datum
binary_encode(PG_FUNCTION_ARGS)47 binary_encode(PG_FUNCTION_ARGS)
48 {
49 	bytea	   *data = PG_GETARG_BYTEA_PP(0);
50 	Datum		name = PG_GETARG_DATUM(1);
51 	text	   *result;
52 	char	   *namebuf;
53 	char	   *dataptr;
54 	size_t		datalen;
55 	uint64		resultlen;
56 	uint64		res;
57 	const struct pg_encoding *enc;
58 
59 	namebuf = TextDatumGetCString(name);
60 
61 	enc = pg_find_encoding(namebuf);
62 	if (enc == NULL)
63 		ereport(ERROR,
64 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
65 				 errmsg("unrecognized encoding: \"%s\"", namebuf)));
66 
67 	dataptr = VARDATA_ANY(data);
68 	datalen = VARSIZE_ANY_EXHDR(data);
69 
70 	resultlen = enc->encode_len(dataptr, datalen);
71 
72 	/*
73 	 * resultlen possibly overflows uint32, therefore on 32-bit machines it's
74 	 * unsafe to rely on palloc's internal check.
75 	 */
76 	if (resultlen > MaxAllocSize - VARHDRSZ)
77 		ereport(ERROR,
78 				(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
79 				 errmsg("result of encoding conversion is too large")));
80 
81 	result = palloc(VARHDRSZ + resultlen);
82 
83 	res = enc->encode(dataptr, datalen, VARDATA(result));
84 
85 	/* Make this FATAL 'cause we've trodden on memory ... */
86 	if (res > resultlen)
87 		elog(FATAL, "overflow - encode estimate too small");
88 
89 	SET_VARSIZE(result, VARHDRSZ + res);
90 
91 	PG_RETURN_TEXT_P(result);
92 }
93 
94 Datum
binary_decode(PG_FUNCTION_ARGS)95 binary_decode(PG_FUNCTION_ARGS)
96 {
97 	text	   *data = PG_GETARG_TEXT_PP(0);
98 	Datum		name = PG_GETARG_DATUM(1);
99 	bytea	   *result;
100 	char	   *namebuf;
101 	char	   *dataptr;
102 	size_t		datalen;
103 	uint64		resultlen;
104 	uint64		res;
105 	const struct pg_encoding *enc;
106 
107 	namebuf = TextDatumGetCString(name);
108 
109 	enc = pg_find_encoding(namebuf);
110 	if (enc == NULL)
111 		ereport(ERROR,
112 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
113 				 errmsg("unrecognized encoding: \"%s\"", namebuf)));
114 
115 	dataptr = VARDATA_ANY(data);
116 	datalen = VARSIZE_ANY_EXHDR(data);
117 
118 	resultlen = enc->decode_len(dataptr, datalen);
119 
120 	/*
121 	 * resultlen possibly overflows uint32, therefore on 32-bit machines it's
122 	 * unsafe to rely on palloc's internal check.
123 	 */
124 	if (resultlen > MaxAllocSize - VARHDRSZ)
125 		ereport(ERROR,
126 				(errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
127 				 errmsg("result of decoding conversion is too large")));
128 
129 	result = palloc(VARHDRSZ + resultlen);
130 
131 	res = enc->decode(dataptr, datalen, VARDATA(result));
132 
133 	/* Make this FATAL 'cause we've trodden on memory ... */
134 	if (res > resultlen)
135 		elog(FATAL, "overflow - decode estimate too small");
136 
137 	SET_VARSIZE(result, VARHDRSZ + res);
138 
139 	PG_RETURN_BYTEA_P(result);
140 }
141 
142 
143 /*
144  * HEX
145  */
146 
/* Lowercase hex digits emitted by hex_encode, indexed by nibble value 0..15. */
static const char hextbl[] = "0123456789abcdef";

/*
 * Maps a 7-bit ASCII code to its hexadecimal digit value, or -1 if the
 * character is not a hex digit.  Both 'A'..'F' and 'a'..'f' map to 10..15.
 * The table has 128 entries, so callers must range-check the index first.
 */
static const int8 hexlookup[128] = {
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1,
	-1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
};
159 
160 uint64
hex_encode(const char * src,size_t len,char * dst)161 hex_encode(const char *src, size_t len, char *dst)
162 {
163 	const char *end = src + len;
164 
165 	while (src < end)
166 	{
167 		*dst++ = hextbl[(*src >> 4) & 0xF];
168 		*dst++ = hextbl[*src & 0xF];
169 		src++;
170 	}
171 	return (uint64) len * 2;
172 }
173 
174 static inline char
get_hex(const char * cp)175 get_hex(const char *cp)
176 {
177 	unsigned char c = (unsigned char) *cp;
178 	int			res = -1;
179 
180 	if (c < 127)
181 		res = hexlookup[c];
182 
183 	if (res < 0)
184 		ereport(ERROR,
185 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
186 				 errmsg("invalid hexadecimal digit: \"%.*s\"",
187 						pg_mblen(cp), cp)));
188 
189 	return (char) res;
190 }
191 
192 uint64
hex_decode(const char * src,size_t len,char * dst)193 hex_decode(const char *src, size_t len, char *dst)
194 {
195 	const char *s,
196 			   *srcend;
197 	char		v1,
198 				v2,
199 			   *p;
200 
201 	srcend = src + len;
202 	s = src;
203 	p = dst;
204 	while (s < srcend)
205 	{
206 		if (*s == ' ' || *s == '\n' || *s == '\t' || *s == '\r')
207 		{
208 			s++;
209 			continue;
210 		}
211 		v1 = get_hex(s) << 4;
212 		s++;
213 		if (s >= srcend)
214 			ereport(ERROR,
215 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
216 					 errmsg("invalid hexadecimal data: odd number of digits")));
217 
218 		v2 = get_hex(s);
219 		s++;
220 		*p++ = v1 | v2;
221 	}
222 
223 	return p - dst;
224 }
225 
226 static uint64
hex_enc_len(const char * src,size_t srclen)227 hex_enc_len(const char *src, size_t srclen)
228 {
229 	return (uint64) srclen << 1;
230 }
231 
232 static uint64
hex_dec_len(const char * src,size_t srclen)233 hex_dec_len(const char *src, size_t srclen)
234 {
235 	return (uint64) srclen >> 1;
236 }
237 
238 /*
239  * BASE64
240  */
241 
/*
 * The 64 symbols used by pg_base64_encode, indexed by 6-bit value
 * ('+' is 62 and '/' is 63).
 */
static const char _base64[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * Inverse of _base64 for 7-bit ASCII: maps a character to its 6-bit value,
 * or -1 if it is not a base64 symbol.  The '=' padding character maps to -1
 * here; pg_base64_decode handles it before consulting this table.
 */
static const int8 b64lookup[128] = {
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63,
	52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1,
	-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
	15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1,
	-1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
	41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1,
};
255 
256 static uint64
pg_base64_encode(const char * src,size_t len,char * dst)257 pg_base64_encode(const char *src, size_t len, char *dst)
258 {
259 	char	   *p,
260 			   *lend = dst + 76;
261 	const char *s,
262 			   *end = src + len;
263 	int			pos = 2;
264 	uint32		buf = 0;
265 
266 	s = src;
267 	p = dst;
268 
269 	while (s < end)
270 	{
271 		buf |= (unsigned char) *s << (pos << 3);
272 		pos--;
273 		s++;
274 
275 		/* write it out */
276 		if (pos < 0)
277 		{
278 			*p++ = _base64[(buf >> 18) & 0x3f];
279 			*p++ = _base64[(buf >> 12) & 0x3f];
280 			*p++ = _base64[(buf >> 6) & 0x3f];
281 			*p++ = _base64[buf & 0x3f];
282 
283 			pos = 2;
284 			buf = 0;
285 		}
286 		if (p >= lend)
287 		{
288 			*p++ = '\n';
289 			lend = p + 76;
290 		}
291 	}
292 	if (pos != 2)
293 	{
294 		*p++ = _base64[(buf >> 18) & 0x3f];
295 		*p++ = _base64[(buf >> 12) & 0x3f];
296 		*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
297 		*p++ = '=';
298 	}
299 
300 	return p - dst;
301 }
302 
303 static uint64
pg_base64_decode(const char * src,size_t len,char * dst)304 pg_base64_decode(const char *src, size_t len, char *dst)
305 {
306 	const char *srcend = src + len,
307 			   *s = src;
308 	char	   *p = dst;
309 	char		c;
310 	int			b = 0;
311 	uint32		buf = 0;
312 	int			pos = 0,
313 				end = 0;
314 
315 	while (s < srcend)
316 	{
317 		c = *s++;
318 
319 		if (c == ' ' || c == '\t' || c == '\n' || c == '\r')
320 			continue;
321 
322 		if (c == '=')
323 		{
324 			/* end sequence */
325 			if (!end)
326 			{
327 				if (pos == 2)
328 					end = 1;
329 				else if (pos == 3)
330 					end = 2;
331 				else
332 					ereport(ERROR,
333 							(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
334 							 errmsg("unexpected \"=\" while decoding base64 sequence")));
335 			}
336 			b = 0;
337 		}
338 		else
339 		{
340 			b = -1;
341 			if (c > 0 && c < 127)
342 				b = b64lookup[(unsigned char) c];
343 			if (b < 0)
344 				ereport(ERROR,
345 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
346 						 errmsg("invalid symbol \"%.*s\" found while decoding base64 sequence",
347 								pg_mblen(s - 1), s - 1)));
348 		}
349 		/* add it to buffer */
350 		buf = (buf << 6) + b;
351 		pos++;
352 		if (pos == 4)
353 		{
354 			*p++ = (buf >> 16) & 255;
355 			if (end == 0 || end > 1)
356 				*p++ = (buf >> 8) & 255;
357 			if (end == 0 || end > 2)
358 				*p++ = buf & 255;
359 			buf = 0;
360 			pos = 0;
361 		}
362 	}
363 
364 	if (pos != 0)
365 		ereport(ERROR,
366 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
367 				 errmsg("invalid base64 end sequence"),
368 				 errhint("Input data is missing padding, is truncated, or is otherwise corrupted.")));
369 
370 	return p - dst;
371 }
372 
373 
374 static uint64
pg_base64_enc_len(const char * src,size_t srclen)375 pg_base64_enc_len(const char *src, size_t srclen)
376 {
377 	/* 3 bytes will be converted to 4, linefeed after 76 chars */
378 	return ((uint64) srclen + 2) * 4 / 3 + (uint64) srclen / (76 * 3 / 4);
379 }
380 
381 static uint64
pg_base64_dec_len(const char * src,size_t srclen)382 pg_base64_dec_len(const char *src, size_t srclen)
383 {
384 	return ((uint64) srclen * 3) >> 2;
385 }
386 
387 /*
388  * Escape
389  * Minimally escape bytea to text.
390  * De-escape text to bytea.
391  *
392  * We must escape zero bytes and high-bit-set bytes to avoid generating
393  * text that might be invalid in the current encoding, or that might
394  * change to something else if passed through an encoding conversion
395  * (leading to failing to de-escape to the original bytea value).
396  * Also of course backslash itself has to be escaped.
397  *
398  * De-escaping processes \\ and any \### octal
399  */
400 
/* Numeric value of an ASCII digit character (used here for octal digits). */
#define VAL(CH)			((CH) - '0')
/* ASCII digit character for a small value (used here for octal digits 0..7). */
#define DIG(VAL)		((VAL) + '0')
403 
404 static uint64
esc_encode(const char * src,size_t srclen,char * dst)405 esc_encode(const char *src, size_t srclen, char *dst)
406 {
407 	const char *end = src + srclen;
408 	char	   *rp = dst;
409 	uint64		len = 0;
410 
411 	while (src < end)
412 	{
413 		unsigned char c = (unsigned char) *src;
414 
415 		if (c == '\0' || IS_HIGHBIT_SET(c))
416 		{
417 			rp[0] = '\\';
418 			rp[1] = DIG(c >> 6);
419 			rp[2] = DIG((c >> 3) & 7);
420 			rp[3] = DIG(c & 7);
421 			rp += 4;
422 			len += 4;
423 		}
424 		else if (c == '\\')
425 		{
426 			rp[0] = '\\';
427 			rp[1] = '\\';
428 			rp += 2;
429 			len += 2;
430 		}
431 		else
432 		{
433 			*rp++ = c;
434 			len++;
435 		}
436 
437 		src++;
438 	}
439 
440 	return len;
441 }
442 
443 static uint64
esc_decode(const char * src,size_t srclen,char * dst)444 esc_decode(const char *src, size_t srclen, char *dst)
445 {
446 	const char *end = src + srclen;
447 	char	   *rp = dst;
448 	uint64		len = 0;
449 
450 	while (src < end)
451 	{
452 		if (src[0] != '\\')
453 			*rp++ = *src++;
454 		else if (src + 3 < end &&
455 				 (src[1] >= '0' && src[1] <= '3') &&
456 				 (src[2] >= '0' && src[2] <= '7') &&
457 				 (src[3] >= '0' && src[3] <= '7'))
458 		{
459 			int			val;
460 
461 			val = VAL(src[1]);
462 			val <<= 3;
463 			val += VAL(src[2]);
464 			val <<= 3;
465 			*rp++ = val + VAL(src[3]);
466 			src += 4;
467 		}
468 		else if (src + 1 < end &&
469 				 (src[1] == '\\'))
470 		{
471 			*rp++ = '\\';
472 			src += 2;
473 		}
474 		else
475 		{
476 			/*
477 			 * One backslash, not followed by ### valid octal. Should never
478 			 * get here, since esc_dec_len does same check.
479 			 */
480 			ereport(ERROR,
481 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
482 					 errmsg("invalid input syntax for type %s", "bytea")));
483 		}
484 
485 		len++;
486 	}
487 
488 	return len;
489 }
490 
491 static uint64
esc_enc_len(const char * src,size_t srclen)492 esc_enc_len(const char *src, size_t srclen)
493 {
494 	const char *end = src + srclen;
495 	uint64		len = 0;
496 
497 	while (src < end)
498 	{
499 		if (*src == '\0' || IS_HIGHBIT_SET(*src))
500 			len += 4;
501 		else if (*src == '\\')
502 			len += 2;
503 		else
504 			len++;
505 
506 		src++;
507 	}
508 
509 	return len;
510 }
511 
512 static uint64
esc_dec_len(const char * src,size_t srclen)513 esc_dec_len(const char *src, size_t srclen)
514 {
515 	const char *end = src + srclen;
516 	uint64		len = 0;
517 
518 	while (src < end)
519 	{
520 		if (src[0] != '\\')
521 			src++;
522 		else if (src + 3 < end &&
523 				 (src[1] >= '0' && src[1] <= '3') &&
524 				 (src[2] >= '0' && src[2] <= '7') &&
525 				 (src[3] >= '0' && src[3] <= '7'))
526 		{
527 			/*
528 			 * backslash + valid octal
529 			 */
530 			src += 4;
531 		}
532 		else if (src + 1 < end &&
533 				 (src[1] == '\\'))
534 		{
535 			/*
536 			 * two backslashes = backslash
537 			 */
538 			src += 2;
539 		}
540 		else
541 		{
542 			/*
543 			 * one backslash, not followed by ### valid octal
544 			 */
545 			ereport(ERROR,
546 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
547 					 errmsg("invalid input syntax for type %s", "bytea")));
548 		}
549 
550 		len++;
551 	}
552 	return len;
553 }
554 
555 /*
556  * Common
557  */
558 
559 static const struct
560 {
561 	const char *name;
562 	struct pg_encoding enc;
563 }			enclist[] =
564 
565 {
566 	{
567 		"hex",
568 		{
569 			hex_enc_len, hex_dec_len, hex_encode, hex_decode
570 		}
571 	},
572 	{
573 		"base64",
574 		{
575 			pg_base64_enc_len, pg_base64_dec_len, pg_base64_encode, pg_base64_decode
576 		}
577 	},
578 	{
579 		"escape",
580 		{
581 			esc_enc_len, esc_dec_len, esc_encode, esc_decode
582 		}
583 	},
584 	{
585 		NULL,
586 		{
587 			NULL, NULL, NULL, NULL
588 		}
589 	}
590 };
591 
592 static const struct pg_encoding *
pg_find_encoding(const char * name)593 pg_find_encoding(const char *name)
594 {
595 	int			i;
596 
597 	for (i = 0; enclist[i].name; i++)
598 		if (pg_strcasecmp(enclist[i].name, name) == 0)
599 			return &enclist[i].enc;
600 
601 	return NULL;
602 }
603