1 /*-------------------------------------------------------------------------
2  *
3  * encode.c
4  *	  Various data encoding/decoding things.
5  *
6  * Copyright (c) 2001-2018, PostgreSQL Global Development Group
7  *
8  *
9  * IDENTIFICATION
10  *	  src/backend/utils/adt/encode.c
11  *
12  *-------------------------------------------------------------------------
13  */
14 #include "postgres.h"
15 
16 #include <ctype.h>
17 
18 #include "utils/builtins.h"
19 
20 
/*
 * Method table describing one encoding.
 *
 * encode_len/decode_len return the size used to allocate the result buffer;
 * encode/decode write their output into "res" and return the number of bytes
 * actually produced.  The SQL-level callers treat a produced length larger
 * than the corresponding *_len result as a fatal buffer overrun.
 *
 * Field order matters: the encoding list in this file initializes these
 * members positionally.
 */
struct pg_encoding
{
	/* result buffer size for encoding dlen bytes at data */
	unsigned	(*encode_len) (const char *data, unsigned dlen);
	/* result buffer size for decoding dlen bytes at data */
	unsigned	(*decode_len) (const char *data, unsigned dlen);
	/* encode dlen bytes at data into res; returns bytes written */
	unsigned	(*encode) (const char *data, unsigned dlen, char *res);
	/* decode dlen bytes at data into res; returns bytes written */
	unsigned	(*decode) (const char *data, unsigned dlen, char *res);
};
28 
/* Look up an encoding by name; yields NULL when the name is not recognized. */
static const struct pg_encoding *pg_find_encoding(const char *name);
30 
31 /*
32  * SQL functions.
33  */
34 
35 Datum
36 binary_encode(PG_FUNCTION_ARGS)
37 {
38 	bytea	   *data = PG_GETARG_BYTEA_PP(0);
39 	Datum		name = PG_GETARG_DATUM(1);
40 	text	   *result;
41 	char	   *namebuf;
42 	int			datalen,
43 				resultlen,
44 				res;
45 	const struct pg_encoding *enc;
46 
47 	datalen = VARSIZE_ANY_EXHDR(data);
48 
49 	namebuf = TextDatumGetCString(name);
50 
51 	enc = pg_find_encoding(namebuf);
52 	if (enc == NULL)
53 		ereport(ERROR,
54 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
55 				 errmsg("unrecognized encoding: \"%s\"", namebuf)));
56 
57 	resultlen = enc->encode_len(VARDATA_ANY(data), datalen);
58 	result = palloc(VARHDRSZ + resultlen);
59 
60 	res = enc->encode(VARDATA_ANY(data), datalen, VARDATA(result));
61 
62 	/* Make this FATAL 'cause we've trodden on memory ... */
63 	if (res > resultlen)
64 		elog(FATAL, "overflow - encode estimate too small");
65 
66 	SET_VARSIZE(result, VARHDRSZ + res);
67 
68 	PG_RETURN_TEXT_P(result);
69 }
70 
71 Datum
72 binary_decode(PG_FUNCTION_ARGS)
73 {
74 	text	   *data = PG_GETARG_TEXT_PP(0);
75 	Datum		name = PG_GETARG_DATUM(1);
76 	bytea	   *result;
77 	char	   *namebuf;
78 	int			datalen,
79 				resultlen,
80 				res;
81 	const struct pg_encoding *enc;
82 
83 	datalen = VARSIZE_ANY_EXHDR(data);
84 
85 	namebuf = TextDatumGetCString(name);
86 
87 	enc = pg_find_encoding(namebuf);
88 	if (enc == NULL)
89 		ereport(ERROR,
90 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
91 				 errmsg("unrecognized encoding: \"%s\"", namebuf)));
92 
93 	resultlen = enc->decode_len(VARDATA_ANY(data), datalen);
94 	result = palloc(VARHDRSZ + resultlen);
95 
96 	res = enc->decode(VARDATA_ANY(data), datalen, VARDATA(result));
97 
98 	/* Make this FATAL 'cause we've trodden on memory ... */
99 	if (res > resultlen)
100 		elog(FATAL, "overflow - decode estimate too small");
101 
102 	SET_VARSIZE(result, VARHDRSZ + res);
103 
104 	PG_RETURN_BYTEA_P(result);
105 }
106 
107 
108 /*
109  * HEX
110  */
111 
/* Lower-case hex digit characters, indexed by nibble value (0..15). */
static const char hextbl[] = "0123456789abcdef";

/*
 * Maps a 7-bit ASCII code to the value of the hex digit it denotes
 * ('0'-'9', 'A'-'F' and 'a'-'f'), or -1 when the character is not a hex
 * digit.  One row per 16 character codes.
 */
static const int8 hexlookup[128] = {
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1,
	-1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
};
124 
125 unsigned
126 hex_encode(const char *src, unsigned len, char *dst)
127 {
128 	const char *end = src + len;
129 
130 	while (src < end)
131 	{
132 		*dst++ = hextbl[(*src >> 4) & 0xF];
133 		*dst++ = hextbl[*src & 0xF];
134 		src++;
135 	}
136 	return len * 2;
137 }
138 
139 static inline char
140 get_hex(char c)
141 {
142 	int			res = -1;
143 
144 	if (c > 0 && c < 127)
145 		res = hexlookup[(unsigned char) c];
146 
147 	if (res < 0)
148 		ereport(ERROR,
149 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
150 				 errmsg("invalid hexadecimal digit: \"%c\"", c)));
151 
152 	return (char) res;
153 }
154 
155 unsigned
156 hex_decode(const char *src, unsigned len, char *dst)
157 {
158 	const char *s,
159 			   *srcend;
160 	char		v1,
161 				v2,
162 			   *p;
163 
164 	srcend = src + len;
165 	s = src;
166 	p = dst;
167 	while (s < srcend)
168 	{
169 		if (*s == ' ' || *s == '\n' || *s == '\t' || *s == '\r')
170 		{
171 			s++;
172 			continue;
173 		}
174 		v1 = get_hex(*s++) << 4;
175 		if (s >= srcend)
176 			ereport(ERROR,
177 					(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
178 					 errmsg("invalid hexadecimal data: odd number of digits")));
179 
180 		v2 = get_hex(*s++);
181 		*p++ = v1 | v2;
182 	}
183 
184 	return p - dst;
185 }
186 
/*
 * hex_enc_len
 *		Output size for hex-encoding srclen bytes: two characters per byte.
 *		The input data itself is not examined.
 */
static unsigned
hex_enc_len(const char *src, unsigned srclen)
{
	return srclen * 2;
}
192 
/*
 * hex_dec_len
 *		Output size for hex-decoding srclen characters: one byte per two
 *		characters, rounding down.  The input data itself is not examined.
 */
static unsigned
hex_dec_len(const char *src, unsigned srclen)
{
	return srclen / 2;
}
198 
199 /*
200  * BASE64
201  */
202 
/* The 64 base64 output symbols, indexed by 6-bit value. */
static const char _base64[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * Maps a 7-bit ASCII code to its 6-bit base64 value, or -1 when the
 * character is not one of the 64 symbols.  The padding character '=' is -1
 * here as well.  One row per 16 character codes.
 */
static const int8 b64lookup[128] = {
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
	-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63,
	52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1,
	-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
	15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1,
	-1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
	41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1,
};
216 
217 static unsigned
218 pg_base64_encode(const char *src, unsigned len, char *dst)
219 {
220 	char	   *p,
221 			   *lend = dst + 76;
222 	const char *s,
223 			   *end = src + len;
224 	int			pos = 2;
225 	uint32		buf = 0;
226 
227 	s = src;
228 	p = dst;
229 
230 	while (s < end)
231 	{
232 		buf |= (unsigned char) *s << (pos << 3);
233 		pos--;
234 		s++;
235 
236 		/* write it out */
237 		if (pos < 0)
238 		{
239 			*p++ = _base64[(buf >> 18) & 0x3f];
240 			*p++ = _base64[(buf >> 12) & 0x3f];
241 			*p++ = _base64[(buf >> 6) & 0x3f];
242 			*p++ = _base64[buf & 0x3f];
243 
244 			pos = 2;
245 			buf = 0;
246 		}
247 		if (p >= lend)
248 		{
249 			*p++ = '\n';
250 			lend = p + 76;
251 		}
252 	}
253 	if (pos != 2)
254 	{
255 		*p++ = _base64[(buf >> 18) & 0x3f];
256 		*p++ = _base64[(buf >> 12) & 0x3f];
257 		*p++ = (pos == 0) ? _base64[(buf >> 6) & 0x3f] : '=';
258 		*p++ = '=';
259 	}
260 
261 	return p - dst;
262 }
263 
264 static unsigned
265 pg_base64_decode(const char *src, unsigned len, char *dst)
266 {
267 	const char *srcend = src + len,
268 			   *s = src;
269 	char	   *p = dst;
270 	char		c;
271 	int			b = 0;
272 	uint32		buf = 0;
273 	int			pos = 0,
274 				end = 0;
275 
276 	while (s < srcend)
277 	{
278 		c = *s++;
279 
280 		if (c == ' ' || c == '\t' || c == '\n' || c == '\r')
281 			continue;
282 
283 		if (c == '=')
284 		{
285 			/* end sequence */
286 			if (!end)
287 			{
288 				if (pos == 2)
289 					end = 1;
290 				else if (pos == 3)
291 					end = 2;
292 				else
293 					ereport(ERROR,
294 							(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
295 							 errmsg("unexpected \"=\" while decoding base64 sequence")));
296 			}
297 			b = 0;
298 		}
299 		else
300 		{
301 			b = -1;
302 			if (c > 0 && c < 127)
303 				b = b64lookup[(unsigned char) c];
304 			if (b < 0)
305 				ereport(ERROR,
306 						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
307 						 errmsg("invalid symbol \"%c\" while decoding base64 sequence", (int) c)));
308 		}
309 		/* add it to buffer */
310 		buf = (buf << 6) + b;
311 		pos++;
312 		if (pos == 4)
313 		{
314 			*p++ = (buf >> 16) & 255;
315 			if (end == 0 || end > 1)
316 				*p++ = (buf >> 8) & 255;
317 			if (end == 0 || end > 2)
318 				*p++ = buf & 255;
319 			buf = 0;
320 			pos = 0;
321 		}
322 	}
323 
324 	if (pos != 0)
325 		ereport(ERROR,
326 				(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
327 				 errmsg("invalid base64 end sequence"),
328 				 errhint("Input data is missing padding, is truncated, or is otherwise corrupted.")));
329 
330 	return p - dst;
331 }
332 
333 
/*
 * pg_base64_enc_len
 *		Buffer size estimate for base64-encoding srclen bytes.  The input
 *		data itself is not examined.
 */
static unsigned
pg_base64_enc_len(const char *src, unsigned srclen)
{
	/* every 3 input bytes turn into 4 symbols */
	unsigned	symbols = (srclen + 2) * 4 / 3;

	/* one linefeed per 76 output characters, i.e. per 76 * 3 / 4 input bytes */
	unsigned	linefeeds = srclen / (76 * 3 / 4);

	return symbols + linefeeds;
}
340 
/*
 * pg_base64_dec_len
 *		Buffer size estimate for base64-decoding srclen characters: three
 *		bytes per four characters, rounding down.  The input data itself is
 *		not examined.
 */
static unsigned
pg_base64_dec_len(const char *src, unsigned srclen)
{
	return (srclen * 3) / 4;
}
346 
347 /*
348  * Escape
349  * Minimally escape bytea to text.
350  * De-escape text to bytea.
351  *
352  * We must escape zero bytes and high-bit-set bytes to avoid generating
353  * text that might be invalid in the current encoding, or that might
354  * change to something else if passed through an encoding conversion
355  * (leading to failing to de-escape to the original bytea value).
356  * Also of course backslash itself has to be escaped.
357  *
358  * De-escaping processes \\ and any \### octal
359  */
360 
/* VAL: numeric value of an ASCII digit character; DIG: the inverse. */
#define VAL(CH)			((CH) - '0')
#define DIG(VAL)		((VAL) + '0')
363 
/*
 * esc_encode
 *		Write the "escape" text form of the srclen bytes at src to dst.
 *
 * A zero byte or a byte with the high bit set becomes a backslash followed
 * by three octal digits; a backslash becomes two backslashes; any other
 * byte is copied unchanged.
 *
 * Returns the number of characters written to dst.
 */
static unsigned
esc_encode(const char *src, unsigned srclen, char *dst)
{
	const char *cur;
	const char *limit = src + srclen;
	char	   *out = dst;

	for (cur = src; cur < limit; cur++)
	{
		unsigned char byte = (unsigned char) *cur;

		if (byte == '\\')
		{
			/* backslash is doubled */
			*out++ = '\\';
			*out++ = '\\';
		}
		else if (byte == '\0' || IS_HIGHBIT_SET(byte))
		{
			/* \ooo octal escape, most significant digit first */
			*out++ = '\\';
			*out++ = DIG(byte >> 6);
			*out++ = DIG((byte >> 3) & 7);
			*out++ = DIG(byte & 7);
		}
		else
		{
			/* everything else passes through as-is */
			*out++ = byte;
		}
	}

	return out - dst;
}
402 
403 static unsigned
404 esc_decode(const char *src, unsigned srclen, char *dst)
405 {
406 	const char *end = src + srclen;
407 	char	   *rp = dst;
408 	int			len = 0;
409 
410 	while (src < end)
411 	{
412 		if (src[0] != '\\')
413 			*rp++ = *src++;
414 		else if (src + 3 < end &&
415 				 (src[1] >= '0' && src[1] <= '3') &&
416 				 (src[2] >= '0' && src[2] <= '7') &&
417 				 (src[3] >= '0' && src[3] <= '7'))
418 		{
419 			int			val;
420 
421 			val = VAL(src[1]);
422 			val <<= 3;
423 			val += VAL(src[2]);
424 			val <<= 3;
425 			*rp++ = val + VAL(src[3]);
426 			src += 4;
427 		}
428 		else if (src + 1 < end &&
429 				 (src[1] == '\\'))
430 		{
431 			*rp++ = '\\';
432 			src += 2;
433 		}
434 		else
435 		{
436 			/*
437 			 * One backslash, not followed by ### valid octal. Should never
438 			 * get here, since esc_dec_len does same check.
439 			 */
440 			ereport(ERROR,
441 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
442 					 errmsg("invalid input syntax for type %s", "bytea")));
443 		}
444 
445 		len++;
446 	}
447 
448 	return len;
449 }
450 
/*
 * esc_enc_len
 *		Compute the number of characters the "escape" text form of the
 *		srclen bytes at src occupies: 4 for a zero or high-bit-set byte,
 *		2 for a backslash, 1 for anything else.
 */
static unsigned
esc_enc_len(const char *src, unsigned srclen)
{
	const char *cur;
	const char *limit = src + srclen;
	int			total = 0;

	for (cur = src; cur < limit; cur++)
	{
		if (*cur == '\\')
			total += 2;
		else if (*cur == '\0' || IS_HIGHBIT_SET(*cur))
			total += 4;
		else
			total += 1;
	}

	return total;
}
471 
472 static unsigned
473 esc_dec_len(const char *src, unsigned srclen)
474 {
475 	const char *end = src + srclen;
476 	int			len = 0;
477 
478 	while (src < end)
479 	{
480 		if (src[0] != '\\')
481 			src++;
482 		else if (src + 3 < end &&
483 				 (src[1] >= '0' && src[1] <= '3') &&
484 				 (src[2] >= '0' && src[2] <= '7') &&
485 				 (src[3] >= '0' && src[3] <= '7'))
486 		{
487 			/*
488 			 * backslash + valid octal
489 			 */
490 			src += 4;
491 		}
492 		else if (src + 1 < end &&
493 				 (src[1] == '\\'))
494 		{
495 			/*
496 			 * two backslashes = backslash
497 			 */
498 			src += 2;
499 		}
500 		else
501 		{
502 			/*
503 			 * one backslash, not followed by ### valid octal
504 			 */
505 			ereport(ERROR,
506 					(errcode(ERRCODE_INVALID_TEXT_REPRESENTATION),
507 					 errmsg("invalid input syntax for type %s", "bytea")));
508 		}
509 
510 		len++;
511 	}
512 	return len;
513 }
514 
515 /*
516  * Common
517  */
518 
/*
 * Table of supported encodings, searched by name in pg_find_encoding().
 *
 * The function pointers of each entry are given positionally, in the field
 * order of struct pg_encoding: encode_len, decode_len, encode, decode.
 * The list is terminated by an entry whose name is NULL.
 */
static const struct
{
	const char *name;
	struct pg_encoding enc;
}			enclist[] =

{
	{
		"hex",
		{
			hex_enc_len, hex_dec_len, hex_encode, hex_decode
		}
	},
	{
		"base64",
		{
			pg_base64_enc_len, pg_base64_dec_len, pg_base64_encode, pg_base64_decode
		}
	},
	{
		"escape",
		{
			esc_enc_len, esc_dec_len, esc_encode, esc_decode
		}
	},
	/* end-of-list marker */
	{
		NULL,
		{
			NULL, NULL, NULL, NULL
		}
	}
};
551 
552 static const struct pg_encoding *
553 pg_find_encoding(const char *name)
554 {
555 	int			i;
556 
557 	for (i = 0; enclist[i].name; i++)
558 		if (pg_strcasecmp(enclist[i].name, name) == 0)
559 			return &enclist[i].enc;
560 
561 	return NULL;
562 }
563