1 /*-
2  * Copyright 2021 Vsevolod Stakhov
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lua_common.h"
18 #include "unix-std.h"
19 #include <zlib.h>
20 
21 #ifdef SYS_ZSTD
22 #  include "zstd.h"
23 #  include "zstd_errors.h"
24 #else
25 #  include "contrib/zstd/zstd.h"
26 #  include "contrib/zstd/error_public.h"
27 #endif
28 
29 /***
30  * @module rspamd_compress
31  * This module contains compression/decompression routines (zstd and zlib currently)
32  */
33 
34 /***
35  * @function zstd.compress_ctx()
36  * Creates new compression ctx
37  * @return {compress_ctx} new compress ctx
38  */
39 LUA_FUNCTION_DEF (zstd, compress_ctx);
40 
/***
 * @function zstd.decompress_ctx()
 * Creates new decompression ctx
 * @return {decompress_ctx} new decompress ctx
 */
46 LUA_FUNCTION_DEF (zstd, decompress_ctx);
47 
48 LUA_FUNCTION_DEF (zstd_compress, stream);
49 LUA_FUNCTION_DEF (zstd_compress, dtor);
50 
51 LUA_FUNCTION_DEF (zstd_decompress, stream);
52 LUA_FUNCTION_DEF (zstd_decompress, dtor);
53 
54 static const struct luaL_reg zstd_compress_lib_f[] = {
55 		LUA_INTERFACE_DEF (zstd, compress_ctx),
56 		LUA_INTERFACE_DEF (zstd, decompress_ctx),
57 		{NULL, NULL}
58 };
59 
60 static const struct luaL_reg zstd_compress_lib_m[] = {
61 		LUA_INTERFACE_DEF (zstd_compress, stream),
62 		{"__gc", lua_zstd_compress_dtor},
63 		{NULL, NULL}
64 };
65 
66 static const struct luaL_reg zstd_decompress_lib_m[] = {
67 		LUA_INTERFACE_DEF (zstd_decompress, stream),
68 		{"__gc", lua_zstd_decompress_dtor},
69 		{NULL, NULL}
70 };
71 
72 static ZSTD_CStream *
lua_check_zstd_compress_ctx(lua_State * L,gint pos)73 lua_check_zstd_compress_ctx (lua_State *L, gint pos)
74 {
75 	void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_compress}");
76 	luaL_argcheck (L, ud != NULL, pos, "'zstd_compress' expected");
77 	return ud ? *(ZSTD_CStream **)ud : NULL;
78 }
79 
80 static ZSTD_DStream *
lua_check_zstd_decompress_ctx(lua_State * L,gint pos)81 lua_check_zstd_decompress_ctx (lua_State *L, gint pos)
82 {
83 	void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_decompress}");
84 	luaL_argcheck (L, ud != NULL, pos, "'zstd_decompress' expected");
85 	return ud ? *(ZSTD_DStream **)ud : NULL;
86 }
87 
88 int
lua_zstd_push_error(lua_State * L,int err)89 lua_zstd_push_error (lua_State *L, int err)
90 {
91 	lua_pushnil (L);
92 	lua_pushfstring (L, "zstd error %d (%s)", err, ZSTD_getErrorString (err));
93 
94 	return 2;
95 }
96 
97 gint
lua_compress_zstd_compress(lua_State * L)98 lua_compress_zstd_compress (lua_State *L)
99 {
100 	LUA_TRACE_POINT;
101 	struct rspamd_lua_text *t = NULL, *res;
102 	gsize sz, r;
103 	gint comp_level = 1;
104 
105 	t = lua_check_text_or_string (L,1);
106 
107 	if (t == NULL || t->start == NULL) {
108 		return luaL_error (L, "invalid arguments");
109 	}
110 
111 	if (lua_type (L, 2) == LUA_TNUMBER) {
112 		comp_level = lua_tointeger (L, 2);
113 	}
114 
115 	sz = ZSTD_compressBound (t->len);
116 
117 	if (ZSTD_isError (sz)) {
118 		msg_err ("cannot compress data: %s", ZSTD_getErrorName (sz));
119 		lua_pushnil (L);
120 
121 		return 1;
122 	}
123 
124 	res = lua_newuserdata (L, sizeof (*res));
125 	res->start = g_malloc (sz);
126 	res->flags = RSPAMD_TEXT_FLAG_OWN;
127 	rspamd_lua_setclass (L, "rspamd{text}", -1);
128 	r = ZSTD_compress ((void *)res->start, sz, t->start, t->len, comp_level);
129 
130 	if (ZSTD_isError (r)) {
131 		msg_err ("cannot compress data: %s", ZSTD_getErrorName (r));
132 		lua_pop (L, 1); /* Text will be freed here */
133 		lua_pushnil (L);
134 
135 		return 1;
136 	}
137 
138 	res->len = r;
139 
140 	return 1;
141 }
142 
143 gint
lua_compress_zstd_decompress(lua_State * L)144 lua_compress_zstd_decompress (lua_State *L)
145 {
146 	LUA_TRACE_POINT;
147 	struct rspamd_lua_text *t = NULL, *res;
148 	gsize outlen, r;
149 	ZSTD_DStream *zstream;
150 	ZSTD_inBuffer zin;
151 	ZSTD_outBuffer zout;
152 	gchar *out;
153 
154 	t = lua_check_text_or_string (L,1);
155 
156 	if (t == NULL || t->start == NULL) {
157 		return luaL_error (L, "invalid arguments");
158 	}
159 
160 	zstream = ZSTD_createDStream ();
161 	ZSTD_initDStream (zstream);
162 
163 	zin.pos = 0;
164 	zin.src = t->start;
165 	zin.size = t->len;
166 
167 	if ((outlen = ZSTD_getDecompressedSize (zin.src, zin.size)) == 0) {
168 		outlen = ZSTD_DStreamOutSize ();
169 	}
170 
171 	out = g_malloc (outlen);
172 
173 	zout.dst = out;
174 	zout.pos = 0;
175 	zout.size = outlen;
176 
177 	while (zin.pos < zin.size) {
178 		r = ZSTD_decompressStream (zstream, &zout, &zin);
179 
180 		if (ZSTD_isError (r)) {
181 			msg_err ("cannot decompress data: %s", ZSTD_getErrorName (r));
182 			ZSTD_freeDStream (zstream);
183 			g_free (out);
184 			lua_pushstring (L, ZSTD_getErrorName (r));
185 			lua_pushnil (L);
186 
187 			return 2;
188 		}
189 
190 		if (zin.pos < zin.size && zout.pos == zout.size) {
191 			/* We need to extend output buffer */
192 			zout.size = zout.size * 2;
193 			out = g_realloc (zout.dst, zout.size);
194 			zout.dst = out;
195 		}
196 	}
197 
198 	ZSTD_freeDStream (zstream);
199 	lua_pushnil (L); /* Error */
200 	res = lua_newuserdata (L, sizeof (*res));
201 	res->start = out;
202 	res->flags = RSPAMD_TEXT_FLAG_OWN;
203 	rspamd_lua_setclass (L, "rspamd{text}", -1);
204 	res->len = zout.pos;
205 
206 	return 2;
207 }
208 
209 gint
lua_compress_zlib_decompress(lua_State * L,bool is_gzip)210 lua_compress_zlib_decompress (lua_State *L, bool is_gzip)
211 {
212 	LUA_TRACE_POINT;
213 	struct rspamd_lua_text *t = NULL, *res;
214 	gsize sz;
215 	z_stream strm;
216 	gint rc;
217 	guchar *p;
218 	gsize remain;
219 	gssize size_limit = -1;
220 
221 	int windowBits = is_gzip ? (MAX_WBITS + 16) : (MAX_WBITS);
222 
223 	t = lua_check_text_or_string (L,1);
224 
225 	if (t == NULL || t->start == NULL) {
226 		return luaL_error (L, "invalid arguments");
227 	}
228 
229 	if (lua_type (L, 2) == LUA_TNUMBER) {
230 		size_limit = lua_tointeger (L, 2);
231 		if (size_limit <= 0) {
232 			return luaL_error (L, "invalid arguments (size_limit)");
233 		}
234 
235 		sz = MIN (t->len * 2, size_limit);
236 	}
237 	else {
238 		sz = t->len * 2;
239 	}
240 
241 	memset (&strm, 0, sizeof (strm));
242 	/* windowBits +16 to decode gzip, zlib 1.2.0.4+ */
243 
244 	/* Here are dragons to distinguish between raw deflate and zlib */
245 	if (windowBits == MAX_WBITS && t->len > 0) {
246 		if ((int)(unsigned char)((t->start[0] << 4)) != 0x80) {
247 			/* Assume raw deflate */
248 			windowBits = -windowBits;
249 		}
250 	}
251 
252 	rc = inflateInit2 (&strm, windowBits);
253 
254 	if (rc != Z_OK) {
255 		return luaL_error (L, "cannot init zlib");
256 	}
257 
258 	strm.avail_in = t->len;
259 	strm.next_in = (guchar *)t->start;
260 
261 	res = lua_newuserdata (L, sizeof (*res));
262 	res->start = g_malloc (sz);
263 	res->flags = RSPAMD_TEXT_FLAG_OWN;
264 	rspamd_lua_setclass (L, "rspamd{text}", -1);
265 
266 	p = (guchar *)res->start;
267 	remain = sz;
268 
269 	while (strm.avail_in != 0) {
270 		strm.avail_out = remain;
271 		strm.next_out = p;
272 
273 		rc = inflate (&strm, Z_NO_FLUSH);
274 
275 		if (rc != Z_OK && rc != Z_BUF_ERROR) {
276 			if (rc == Z_STREAM_END) {
277 				break;
278 			}
279 			else {
280 				msg_err ("cannot decompress data: %s (last error: %s)",
281 						zError (rc), strm.msg);
282 				lua_pop (L, 1); /* Text will be freed here */
283 				lua_pushnil (L);
284 				inflateEnd (&strm);
285 
286 				return 1;
287 			}
288 		}
289 
290 		res->len = strm.total_out;
291 
292 		if (strm.avail_out == 0 && strm.avail_in != 0) {
293 
294 			if (size_limit > 0 || res->len >= G_MAXUINT32 / 2) {
295 				if (res->len > size_limit || res->len >= G_MAXUINT32 / 2) {
296 					lua_pop (L, 1); /* Text will be freed here */
297 					lua_pushnil (L);
298 					inflateEnd (&strm);
299 
300 					return 1;
301 				}
302 			}
303 
304 			/* Need to allocate more */
305 			remain = res->len;
306 			res->start = g_realloc ((gpointer)res->start, res->len * 2);
307 			sz = res->len * 2;
308 			p = (guchar *)res->start + remain;
309 			remain = sz - remain;
310 		}
311 	}
312 
313 	inflateEnd (&strm);
314 	res->len = strm.total_out;
315 
316 	return 1;
317 }
318 
319 gint
lua_compress_zlib_compress(lua_State * L)320 lua_compress_zlib_compress (lua_State *L)
321 {
322 	LUA_TRACE_POINT;
323 	struct rspamd_lua_text *t = NULL, *res;
324 	gsize sz;
325 	z_stream strm;
326 	gint rc, comp_level = Z_DEFAULT_COMPRESSION;
327 	guchar *p;
328 	gsize remain;
329 
330 	t = lua_check_text_or_string (L,1);
331 
332 	if (t == NULL || t->start == NULL) {
333 		return luaL_error (L, "invalid arguments");
334 	}
335 
336 	if (lua_isnumber (L, 2)) {
337 		comp_level = lua_tointeger (L, 2);
338 
339 		if (comp_level > Z_BEST_COMPRESSION || comp_level < Z_BEST_SPEED) {
340 			return luaL_error (L, "invalid arguments: compression level must be between %d and %d",
341 					Z_BEST_SPEED, Z_BEST_COMPRESSION);
342 		}
343 	}
344 
345 
346 	memset (&strm, 0, sizeof (strm));
347 	rc = deflateInit2 (&strm, comp_level, Z_DEFLATED,
348 			MAX_WBITS + 16, MAX_MEM_LEVEL - 1, Z_DEFAULT_STRATEGY);
349 
350 	if (rc != Z_OK) {
351 		return luaL_error (L, "cannot init zlib: %s", zError (rc));
352 	}
353 
354 	sz = deflateBound (&strm, t->len);
355 
356 	strm.avail_in = t->len;
357 	strm.next_in = (guchar *) t->start;
358 
359 	res = lua_newuserdata (L, sizeof (*res));
360 	res->start = g_malloc (sz);
361 	res->flags = RSPAMD_TEXT_FLAG_OWN;
362 	rspamd_lua_setclass (L, "rspamd{text}", -1);
363 
364 	p = (guchar *) res->start;
365 	remain = sz;
366 
367 	while (strm.avail_in != 0) {
368 		strm.avail_out = remain;
369 		strm.next_out = p;
370 
371 		rc = deflate (&strm, Z_FINISH);
372 
373 		if (rc != Z_OK && rc != Z_BUF_ERROR) {
374 			if (rc == Z_STREAM_END) {
375 				break;
376 			}
377 			else {
378 				msg_err ("cannot compress data: %s (last error: %s)",
379 						zError (rc), strm.msg);
380 				lua_pop (L, 1); /* Text will be freed here */
381 				lua_pushnil (L);
382 				deflateEnd (&strm);
383 
384 				return 1;
385 			}
386 		}
387 
388 		res->len = strm.total_out;
389 
390 		if (strm.avail_out == 0 && strm.avail_in != 0) {
391 			/* Need to allocate more */
392 			remain = res->len;
393 			res->start = g_realloc ((gpointer) res->start, strm.avail_in + sz);
394 			sz = strm.avail_in + sz;
395 			p = (guchar *) res->start + remain;
396 			remain = sz - remain;
397 		}
398 	}
399 
400 	deflateEnd (&strm);
401 	res->len = strm.total_out;
402 
403 	return 1;
404 }
405 
/* Stream API interface for Zstd: both compression and decompression */

/*
 * Operation names accepted by the zstd stream methods. The position of a
 * name matters: the index found by luaL_checkoption is passed to zstd as
 * the operation code. The list is terminated by NULL.
 */
static const char *const zstd_stream_op[] = {
		[0] = "continue",
		[1] = "flush",
		[2] = "end",
		[3] = NULL
};
415 
416 static gint
lua_zstd_compress_ctx(lua_State * L)417 lua_zstd_compress_ctx (lua_State *L)
418 {
419 	ZSTD_CCtx *ctx, **pctx;
420 
421 	pctx = lua_newuserdata (L, sizeof (*pctx));
422 	ctx = ZSTD_createCCtx ();
423 
424 	if (!ctx) {
425 		return luaL_error (L, "context create failed");
426 	}
427 
428 	*pctx = ctx;
429 	rspamd_lua_setclass (L, "rspamd{zstd_compress}", -1);
430 	return 1;
431 }
432 
433 static gint
lua_zstd_compress_dtor(lua_State * L)434 lua_zstd_compress_dtor (lua_State *L)
435 {
436 	ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
437 
438 	if (ctx) {
439 		ZSTD_freeCCtx (ctx);
440 	}
441 
442 	return 0;
443 }
444 
445 static gint
lua_zstd_compress_reset(lua_State * L)446 lua_zstd_compress_reset (lua_State *L)
447 {
448 	ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
449 
450 	if (ctx) {
451 		ZSTD_CCtx_reset (ctx, ZSTD_reset_session_and_parameters);
452 	}
453 	else {
454 		return luaL_error (L, "invalid arguments");
455 	}
456 
457 	return 0;
458 }
459 
460 static gint
lua_zstd_compress_stream(lua_State * L)461 lua_zstd_compress_stream (lua_State *L)
462 {
463 	ZSTD_CStream *ctx = lua_check_zstd_compress_ctx (L, 1);
464 	struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
465 	int op = luaL_checkoption (L, 3, zstd_stream_op[0], zstd_stream_op);
466 	int err = 0;
467 	ZSTD_inBuffer inb;
468 	ZSTD_outBuffer onb;
469 
470 	if (ctx && t) {
471 		gsize dlen = 0;
472 
473 		inb.size = t->len;
474 		inb.pos = 0;
475 		inb.src = (const void*)t->start;
476 
477 		onb.pos = 0;
478 		onb.size = ZSTD_CStreamInSize (); /* Initial guess */
479 		onb.dst = NULL;
480 
481 		for (;;) {
482 			if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
483 				return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
484 			}
485 
486 			dlen = onb.size;
487 
488 			int res = ZSTD_compressStream2 (ctx, &onb, &inb, op);
489 
490 			if (res == 0) {
491 				/* All done */
492 				break;
493 			}
494 
495 			if ((err = ZSTD_getErrorCode (res))) {
496 				break;
497 			}
498 
499 			onb.size *= 2;
500 			res += dlen; /* Hint returned by compression routine */
501 
502 			/* Either double the buffer, or use the hint provided */
503 			if (onb.size < res) {
504 				onb.size = res;
505 			}
506 		}
507 	}
508 	else {
509 		return luaL_error (L, "invalid arguments");
510 	}
511 
512 	if (err) {
513 		return lua_zstd_push_error (L, err);
514 	}
515 
516 	lua_new_text (L, onb.dst, onb.pos, TRUE);
517 
518 	return 1;
519 }
520 
521 static gint
lua_zstd_decompress_dtor(lua_State * L)522 lua_zstd_decompress_dtor (lua_State *L)
523 {
524 	ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
525 
526 	if (ctx) {
527 		ZSTD_freeDStream (ctx);
528 	}
529 
530 	return 0;
531 }
532 
533 
534 static gint
lua_zstd_decompress_ctx(lua_State * L)535 lua_zstd_decompress_ctx (lua_State *L)
536 {
537 	ZSTD_DStream *ctx, **pctx;
538 
539 	pctx = lua_newuserdata (L, sizeof (*pctx));
540 	ctx = ZSTD_createDStream ();
541 
542 	if (!ctx) {
543 		return luaL_error (L, "context create failed");
544 	}
545 
546 	*pctx = ctx;
547 	rspamd_lua_setclass (L, "rspamd{zstd_decompress}", -1);
548 	return 1;
549 }
550 
551 static gint
lua_zstd_decompress_stream(lua_State * L)552 lua_zstd_decompress_stream (lua_State *L)
553 {
554 	ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
555 	struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
556 	int err = 0;
557 	ZSTD_inBuffer inb;
558 	ZSTD_outBuffer onb;
559 
560 	if (ctx && t) {
561 		gsize dlen = 0;
562 
563 		if (t->len == 0) {
564 			return lua_zstd_push_error (L, ZSTD_error_init_missing);
565 		}
566 
567 		inb.size = t->len;
568 		inb.pos = 0;
569 		inb.src = (const void*)t->start;
570 
571 		onb.pos = 0;
572 		onb.size = ZSTD_DStreamInSize (); /* Initial guess */
573 		onb.dst = NULL;
574 
575 		for (;;) {
576 			if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
577 				return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
578 			}
579 
580 			dlen = onb.size;
581 
582 			int res = ZSTD_decompressStream (ctx, &onb, &inb);
583 
584 			if (res == 0) {
585 				/* All done */
586 				break;
587 			}
588 
589 			if ((err = ZSTD_getErrorCode (res))) {
590 				break;
591 			}
592 
593 			onb.size *= 2;
594 			res += dlen; /* Hint returned by compression routine */
595 
596 			/* Either double the buffer, or use the hint provided */
597 			if (onb.size < res) {
598 				onb.size = res;
599 			}
600 		}
601 	}
602 	else {
603 		return luaL_error (L, "invalid arguments");
604 	}
605 
606 	if (err) {
607 		return lua_zstd_push_error (L, err);
608 	}
609 
610 	lua_new_text (L, onb.dst, onb.pos, TRUE);
611 
612 	return 1;
613 }
614 
615 static gint
lua_load_zstd(lua_State * L)616 lua_load_zstd (lua_State * L)
617 {
618 	lua_newtable (L);
619 	luaL_register (L, NULL, zstd_compress_lib_f);
620 
621 	return 1;
622 }
623 
624 void
luaopen_compress(lua_State * L)625 luaopen_compress (lua_State *L)
626 {
627 	rspamd_lua_new_class (L, "rspamd{zstd_compress}", zstd_compress_lib_m);
628 	rspamd_lua_new_class (L, "rspamd{zstd_decompress}", zstd_decompress_lib_m);
629 	lua_pop (L, 2);
630 
631 	rspamd_lua_add_preload (L, "rspamd_zstd", lua_load_zstd);
632 }
633