1 /*-
2 * Copyright 2021 Vsevolod Stakhov
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "lua_common.h"
18 #include "unix-std.h"
19 #include <zlib.h>
20
21 #ifdef SYS_ZSTD
22 # include "zstd.h"
23 # include "zstd_errors.h"
24 #else
25 # include "contrib/zstd/zstd.h"
26 # include "contrib/zstd/error_public.h"
27 #endif
28
29 /***
30 * @module rspamd_compress
31 * This module contains compression/decompression routines (zstd and zlib currently)
32 */
33
34 /***
35 * @function zstd.compress_ctx()
36 * Creates new compression ctx
37 * @return {compress_ctx} new compress ctx
38 */
39 LUA_FUNCTION_DEF (zstd, compress_ctx);
40
41 /***
42 * @function zstd.compress_ctx()
43 * Creates new compression ctx
44 * @return {compress_ctx} new compress ctx
45 */
46 LUA_FUNCTION_DEF (zstd, decompress_ctx);
47
48 LUA_FUNCTION_DEF (zstd_compress, stream);
49 LUA_FUNCTION_DEF (zstd_compress, dtor);
50
51 LUA_FUNCTION_DEF (zstd_decompress, stream);
52 LUA_FUNCTION_DEF (zstd_decompress, dtor);
53
54 static const struct luaL_reg zstd_compress_lib_f[] = {
55 LUA_INTERFACE_DEF (zstd, compress_ctx),
56 LUA_INTERFACE_DEF (zstd, decompress_ctx),
57 {NULL, NULL}
58 };
59
60 static const struct luaL_reg zstd_compress_lib_m[] = {
61 LUA_INTERFACE_DEF (zstd_compress, stream),
62 {"__gc", lua_zstd_compress_dtor},
63 {NULL, NULL}
64 };
65
66 static const struct luaL_reg zstd_decompress_lib_m[] = {
67 LUA_INTERFACE_DEF (zstd_decompress, stream),
68 {"__gc", lua_zstd_decompress_dtor},
69 {NULL, NULL}
70 };
71
72 static ZSTD_CStream *
lua_check_zstd_compress_ctx(lua_State * L,gint pos)73 lua_check_zstd_compress_ctx (lua_State *L, gint pos)
74 {
75 void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_compress}");
76 luaL_argcheck (L, ud != NULL, pos, "'zstd_compress' expected");
77 return ud ? *(ZSTD_CStream **)ud : NULL;
78 }
79
80 static ZSTD_DStream *
lua_check_zstd_decompress_ctx(lua_State * L,gint pos)81 lua_check_zstd_decompress_ctx (lua_State *L, gint pos)
82 {
83 void *ud = rspamd_lua_check_udata (L, pos, "rspamd{zstd_decompress}");
84 luaL_argcheck (L, ud != NULL, pos, "'zstd_decompress' expected");
85 return ud ? *(ZSTD_DStream **)ud : NULL;
86 }
87
88 int
lua_zstd_push_error(lua_State * L,int err)89 lua_zstd_push_error (lua_State *L, int err)
90 {
91 lua_pushnil (L);
92 lua_pushfstring (L, "zstd error %d (%s)", err, ZSTD_getErrorString (err));
93
94 return 2;
95 }
96
97 gint
lua_compress_zstd_compress(lua_State * L)98 lua_compress_zstd_compress (lua_State *L)
99 {
100 LUA_TRACE_POINT;
101 struct rspamd_lua_text *t = NULL, *res;
102 gsize sz, r;
103 gint comp_level = 1;
104
105 t = lua_check_text_or_string (L,1);
106
107 if (t == NULL || t->start == NULL) {
108 return luaL_error (L, "invalid arguments");
109 }
110
111 if (lua_type (L, 2) == LUA_TNUMBER) {
112 comp_level = lua_tointeger (L, 2);
113 }
114
115 sz = ZSTD_compressBound (t->len);
116
117 if (ZSTD_isError (sz)) {
118 msg_err ("cannot compress data: %s", ZSTD_getErrorName (sz));
119 lua_pushnil (L);
120
121 return 1;
122 }
123
124 res = lua_newuserdata (L, sizeof (*res));
125 res->start = g_malloc (sz);
126 res->flags = RSPAMD_TEXT_FLAG_OWN;
127 rspamd_lua_setclass (L, "rspamd{text}", -1);
128 r = ZSTD_compress ((void *)res->start, sz, t->start, t->len, comp_level);
129
130 if (ZSTD_isError (r)) {
131 msg_err ("cannot compress data: %s", ZSTD_getErrorName (r));
132 lua_pop (L, 1); /* Text will be freed here */
133 lua_pushnil (L);
134
135 return 1;
136 }
137
138 res->len = r;
139
140 return 1;
141 }
142
143 gint
lua_compress_zstd_decompress(lua_State * L)144 lua_compress_zstd_decompress (lua_State *L)
145 {
146 LUA_TRACE_POINT;
147 struct rspamd_lua_text *t = NULL, *res;
148 gsize outlen, r;
149 ZSTD_DStream *zstream;
150 ZSTD_inBuffer zin;
151 ZSTD_outBuffer zout;
152 gchar *out;
153
154 t = lua_check_text_or_string (L,1);
155
156 if (t == NULL || t->start == NULL) {
157 return luaL_error (L, "invalid arguments");
158 }
159
160 zstream = ZSTD_createDStream ();
161 ZSTD_initDStream (zstream);
162
163 zin.pos = 0;
164 zin.src = t->start;
165 zin.size = t->len;
166
167 if ((outlen = ZSTD_getDecompressedSize (zin.src, zin.size)) == 0) {
168 outlen = ZSTD_DStreamOutSize ();
169 }
170
171 out = g_malloc (outlen);
172
173 zout.dst = out;
174 zout.pos = 0;
175 zout.size = outlen;
176
177 while (zin.pos < zin.size) {
178 r = ZSTD_decompressStream (zstream, &zout, &zin);
179
180 if (ZSTD_isError (r)) {
181 msg_err ("cannot decompress data: %s", ZSTD_getErrorName (r));
182 ZSTD_freeDStream (zstream);
183 g_free (out);
184 lua_pushstring (L, ZSTD_getErrorName (r));
185 lua_pushnil (L);
186
187 return 2;
188 }
189
190 if (zin.pos < zin.size && zout.pos == zout.size) {
191 /* We need to extend output buffer */
192 zout.size = zout.size * 2;
193 out = g_realloc (zout.dst, zout.size);
194 zout.dst = out;
195 }
196 }
197
198 ZSTD_freeDStream (zstream);
199 lua_pushnil (L); /* Error */
200 res = lua_newuserdata (L, sizeof (*res));
201 res->start = out;
202 res->flags = RSPAMD_TEXT_FLAG_OWN;
203 rspamd_lua_setclass (L, "rspamd{text}", -1);
204 res->len = zout.pos;
205
206 return 2;
207 }
208
209 gint
lua_compress_zlib_decompress(lua_State * L,bool is_gzip)210 lua_compress_zlib_decompress (lua_State *L, bool is_gzip)
211 {
212 LUA_TRACE_POINT;
213 struct rspamd_lua_text *t = NULL, *res;
214 gsize sz;
215 z_stream strm;
216 gint rc;
217 guchar *p;
218 gsize remain;
219 gssize size_limit = -1;
220
221 int windowBits = is_gzip ? (MAX_WBITS + 16) : (MAX_WBITS);
222
223 t = lua_check_text_or_string (L,1);
224
225 if (t == NULL || t->start == NULL) {
226 return luaL_error (L, "invalid arguments");
227 }
228
229 if (lua_type (L, 2) == LUA_TNUMBER) {
230 size_limit = lua_tointeger (L, 2);
231 if (size_limit <= 0) {
232 return luaL_error (L, "invalid arguments (size_limit)");
233 }
234
235 sz = MIN (t->len * 2, size_limit);
236 }
237 else {
238 sz = t->len * 2;
239 }
240
241 memset (&strm, 0, sizeof (strm));
242 /* windowBits +16 to decode gzip, zlib 1.2.0.4+ */
243
244 /* Here are dragons to distinguish between raw deflate and zlib */
245 if (windowBits == MAX_WBITS && t->len > 0) {
246 if ((int)(unsigned char)((t->start[0] << 4)) != 0x80) {
247 /* Assume raw deflate */
248 windowBits = -windowBits;
249 }
250 }
251
252 rc = inflateInit2 (&strm, windowBits);
253
254 if (rc != Z_OK) {
255 return luaL_error (L, "cannot init zlib");
256 }
257
258 strm.avail_in = t->len;
259 strm.next_in = (guchar *)t->start;
260
261 res = lua_newuserdata (L, sizeof (*res));
262 res->start = g_malloc (sz);
263 res->flags = RSPAMD_TEXT_FLAG_OWN;
264 rspamd_lua_setclass (L, "rspamd{text}", -1);
265
266 p = (guchar *)res->start;
267 remain = sz;
268
269 while (strm.avail_in != 0) {
270 strm.avail_out = remain;
271 strm.next_out = p;
272
273 rc = inflate (&strm, Z_NO_FLUSH);
274
275 if (rc != Z_OK && rc != Z_BUF_ERROR) {
276 if (rc == Z_STREAM_END) {
277 break;
278 }
279 else {
280 msg_err ("cannot decompress data: %s (last error: %s)",
281 zError (rc), strm.msg);
282 lua_pop (L, 1); /* Text will be freed here */
283 lua_pushnil (L);
284 inflateEnd (&strm);
285
286 return 1;
287 }
288 }
289
290 res->len = strm.total_out;
291
292 if (strm.avail_out == 0 && strm.avail_in != 0) {
293
294 if (size_limit > 0 || res->len >= G_MAXUINT32 / 2) {
295 if (res->len > size_limit || res->len >= G_MAXUINT32 / 2) {
296 lua_pop (L, 1); /* Text will be freed here */
297 lua_pushnil (L);
298 inflateEnd (&strm);
299
300 return 1;
301 }
302 }
303
304 /* Need to allocate more */
305 remain = res->len;
306 res->start = g_realloc ((gpointer)res->start, res->len * 2);
307 sz = res->len * 2;
308 p = (guchar *)res->start + remain;
309 remain = sz - remain;
310 }
311 }
312
313 inflateEnd (&strm);
314 res->len = strm.total_out;
315
316 return 1;
317 }
318
319 gint
lua_compress_zlib_compress(lua_State * L)320 lua_compress_zlib_compress (lua_State *L)
321 {
322 LUA_TRACE_POINT;
323 struct rspamd_lua_text *t = NULL, *res;
324 gsize sz;
325 z_stream strm;
326 gint rc, comp_level = Z_DEFAULT_COMPRESSION;
327 guchar *p;
328 gsize remain;
329
330 t = lua_check_text_or_string (L,1);
331
332 if (t == NULL || t->start == NULL) {
333 return luaL_error (L, "invalid arguments");
334 }
335
336 if (lua_isnumber (L, 2)) {
337 comp_level = lua_tointeger (L, 2);
338
339 if (comp_level > Z_BEST_COMPRESSION || comp_level < Z_BEST_SPEED) {
340 return luaL_error (L, "invalid arguments: compression level must be between %d and %d",
341 Z_BEST_SPEED, Z_BEST_COMPRESSION);
342 }
343 }
344
345
346 memset (&strm, 0, sizeof (strm));
347 rc = deflateInit2 (&strm, comp_level, Z_DEFLATED,
348 MAX_WBITS + 16, MAX_MEM_LEVEL - 1, Z_DEFAULT_STRATEGY);
349
350 if (rc != Z_OK) {
351 return luaL_error (L, "cannot init zlib: %s", zError (rc));
352 }
353
354 sz = deflateBound (&strm, t->len);
355
356 strm.avail_in = t->len;
357 strm.next_in = (guchar *) t->start;
358
359 res = lua_newuserdata (L, sizeof (*res));
360 res->start = g_malloc (sz);
361 res->flags = RSPAMD_TEXT_FLAG_OWN;
362 rspamd_lua_setclass (L, "rspamd{text}", -1);
363
364 p = (guchar *) res->start;
365 remain = sz;
366
367 while (strm.avail_in != 0) {
368 strm.avail_out = remain;
369 strm.next_out = p;
370
371 rc = deflate (&strm, Z_FINISH);
372
373 if (rc != Z_OK && rc != Z_BUF_ERROR) {
374 if (rc == Z_STREAM_END) {
375 break;
376 }
377 else {
378 msg_err ("cannot compress data: %s (last error: %s)",
379 zError (rc), strm.msg);
380 lua_pop (L, 1); /* Text will be freed here */
381 lua_pushnil (L);
382 deflateEnd (&strm);
383
384 return 1;
385 }
386 }
387
388 res->len = strm.total_out;
389
390 if (strm.avail_out == 0 && strm.avail_in != 0) {
391 /* Need to allocate more */
392 remain = res->len;
393 res->start = g_realloc ((gpointer) res->start, strm.avail_in + sz);
394 sz = strm.avail_in + sz;
395 p = (guchar *) res->start + remain;
396 remain = sz - remain;
397 }
398 }
399
400 deflateEnd (&strm);
401 res->len = strm.total_out;
402
403 return 1;
404 }
405
/* Stream API interface for Zstd: both compression and decompression */

/*
 * Operation names accepted by the zstd stream methods; the index of a name
 * in this NULL-terminated list is the value produced by luaL_checkoption.
 */
static const char *const zstd_stream_op[] = {
	[0] = "continue",
	[1] = "flush",
	[2] = "end",
	[3] = NULL
};
415
416 static gint
lua_zstd_compress_ctx(lua_State * L)417 lua_zstd_compress_ctx (lua_State *L)
418 {
419 ZSTD_CCtx *ctx, **pctx;
420
421 pctx = lua_newuserdata (L, sizeof (*pctx));
422 ctx = ZSTD_createCCtx ();
423
424 if (!ctx) {
425 return luaL_error (L, "context create failed");
426 }
427
428 *pctx = ctx;
429 rspamd_lua_setclass (L, "rspamd{zstd_compress}", -1);
430 return 1;
431 }
432
433 static gint
lua_zstd_compress_dtor(lua_State * L)434 lua_zstd_compress_dtor (lua_State *L)
435 {
436 ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
437
438 if (ctx) {
439 ZSTD_freeCCtx (ctx);
440 }
441
442 return 0;
443 }
444
445 static gint
lua_zstd_compress_reset(lua_State * L)446 lua_zstd_compress_reset (lua_State *L)
447 {
448 ZSTD_CCtx *ctx = lua_check_zstd_compress_ctx (L, 1);
449
450 if (ctx) {
451 ZSTD_CCtx_reset (ctx, ZSTD_reset_session_and_parameters);
452 }
453 else {
454 return luaL_error (L, "invalid arguments");
455 }
456
457 return 0;
458 }
459
460 static gint
lua_zstd_compress_stream(lua_State * L)461 lua_zstd_compress_stream (lua_State *L)
462 {
463 ZSTD_CStream *ctx = lua_check_zstd_compress_ctx (L, 1);
464 struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
465 int op = luaL_checkoption (L, 3, zstd_stream_op[0], zstd_stream_op);
466 int err = 0;
467 ZSTD_inBuffer inb;
468 ZSTD_outBuffer onb;
469
470 if (ctx && t) {
471 gsize dlen = 0;
472
473 inb.size = t->len;
474 inb.pos = 0;
475 inb.src = (const void*)t->start;
476
477 onb.pos = 0;
478 onb.size = ZSTD_CStreamInSize (); /* Initial guess */
479 onb.dst = NULL;
480
481 for (;;) {
482 if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
483 return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
484 }
485
486 dlen = onb.size;
487
488 int res = ZSTD_compressStream2 (ctx, &onb, &inb, op);
489
490 if (res == 0) {
491 /* All done */
492 break;
493 }
494
495 if ((err = ZSTD_getErrorCode (res))) {
496 break;
497 }
498
499 onb.size *= 2;
500 res += dlen; /* Hint returned by compression routine */
501
502 /* Either double the buffer, or use the hint provided */
503 if (onb.size < res) {
504 onb.size = res;
505 }
506 }
507 }
508 else {
509 return luaL_error (L, "invalid arguments");
510 }
511
512 if (err) {
513 return lua_zstd_push_error (L, err);
514 }
515
516 lua_new_text (L, onb.dst, onb.pos, TRUE);
517
518 return 1;
519 }
520
521 static gint
lua_zstd_decompress_dtor(lua_State * L)522 lua_zstd_decompress_dtor (lua_State *L)
523 {
524 ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
525
526 if (ctx) {
527 ZSTD_freeDStream (ctx);
528 }
529
530 return 0;
531 }
532
533
534 static gint
lua_zstd_decompress_ctx(lua_State * L)535 lua_zstd_decompress_ctx (lua_State *L)
536 {
537 ZSTD_DStream *ctx, **pctx;
538
539 pctx = lua_newuserdata (L, sizeof (*pctx));
540 ctx = ZSTD_createDStream ();
541
542 if (!ctx) {
543 return luaL_error (L, "context create failed");
544 }
545
546 *pctx = ctx;
547 rspamd_lua_setclass (L, "rspamd{zstd_decompress}", -1);
548 return 1;
549 }
550
551 static gint
lua_zstd_decompress_stream(lua_State * L)552 lua_zstd_decompress_stream (lua_State *L)
553 {
554 ZSTD_DStream *ctx = lua_check_zstd_decompress_ctx (L, 1);
555 struct rspamd_lua_text *t = lua_check_text_or_string (L, 2);
556 int err = 0;
557 ZSTD_inBuffer inb;
558 ZSTD_outBuffer onb;
559
560 if (ctx && t) {
561 gsize dlen = 0;
562
563 if (t->len == 0) {
564 return lua_zstd_push_error (L, ZSTD_error_init_missing);
565 }
566
567 inb.size = t->len;
568 inb.pos = 0;
569 inb.src = (const void*)t->start;
570
571 onb.pos = 0;
572 onb.size = ZSTD_DStreamInSize (); /* Initial guess */
573 onb.dst = NULL;
574
575 for (;;) {
576 if ((onb.dst = g_realloc (onb.dst, onb.size)) == NULL) {
577 return lua_zstd_push_error (L, ZSTD_error_memory_allocation);
578 }
579
580 dlen = onb.size;
581
582 int res = ZSTD_decompressStream (ctx, &onb, &inb);
583
584 if (res == 0) {
585 /* All done */
586 break;
587 }
588
589 if ((err = ZSTD_getErrorCode (res))) {
590 break;
591 }
592
593 onb.size *= 2;
594 res += dlen; /* Hint returned by compression routine */
595
596 /* Either double the buffer, or use the hint provided */
597 if (onb.size < res) {
598 onb.size = res;
599 }
600 }
601 }
602 else {
603 return luaL_error (L, "invalid arguments");
604 }
605
606 if (err) {
607 return lua_zstd_push_error (L, err);
608 }
609
610 lua_new_text (L, onb.dst, onb.pos, TRUE);
611
612 return 1;
613 }
614
615 static gint
lua_load_zstd(lua_State * L)616 lua_load_zstd (lua_State * L)
617 {
618 lua_newtable (L);
619 luaL_register (L, NULL, zstd_compress_lib_f);
620
621 return 1;
622 }
623
624 void
luaopen_compress(lua_State * L)625 luaopen_compress (lua_State *L)
626 {
627 rspamd_lua_new_class (L, "rspamd{zstd_compress}", zstd_compress_lib_m);
628 rspamd_lua_new_class (L, "rspamd{zstd_decompress}", zstd_decompress_lib_m);
629 lua_pop (L, 2);
630
631 rspamd_lua_add_preload (L, "rspamd_zstd", lua_load_zstd);
632 }
633