1 /*-
2 * Copyright 2016 Vsevolod Stakhov
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "lua_common.h"
17 #include "message.h"
18 #include "libutil/multipattern.h"
19
20 /***
21 * @module rspamd_trie
22 * Rspamd trie module provides the data structure suitable for searching of many
23 * patterns in arbitrary texts (or binary chunks). The algorithmic complexity of
 * this algorithm is at most O(n + m + z), where `n` is the length of the text, `m` is the length of a pattern and `z` is the number of pattern matches found in the text.
25 *
26 * Here is a typical example of trie usage:
27 * @example
28 local rspamd_trie = require "rspamd_trie"
29 local patterns = {'aab', 'ab', 'bcd\0ef'}
30
31 local trie = rspamd_trie.create(patterns)
32
33 local function trie_callback(number, pos)
34 print('Matched pattern number ' .. tostring(number) .. ' at pos: ' .. tostring(pos))
35 end
36
37 trie:match('some big text', trie_callback)
38 */
39
40 /* Suffix trie */
/* Module-level functions (registered in `trielib_f`) */
LUA_FUNCTION_DEF (trie, create);
LUA_FUNCTION_DEF (trie, has_hyperscan);

/* Methods of a trie object (registered in `trielib_m`) */
LUA_FUNCTION_DEF (trie, match);
LUA_FUNCTION_DEF (trie, search_mime);
LUA_FUNCTION_DEF (trie, search_rawmsg);
LUA_FUNCTION_DEF (trie, search_rawbody);

/* Finaliser, installed as the `__gc` metamethod */
LUA_FUNCTION_DEF (trie, destroy);
48
49 static const struct luaL_reg trielib_m[] = {
50 LUA_INTERFACE_DEF (trie, match),
51 LUA_INTERFACE_DEF (trie, search_mime),
52 LUA_INTERFACE_DEF (trie, search_rawmsg),
53 LUA_INTERFACE_DEF (trie, search_rawbody),
54 {"__tostring", rspamd_lua_class_tostring},
55 {"__gc", lua_trie_destroy},
56 {NULL, NULL}
57 };
58 static const struct luaL_reg trielib_f[] = {
59 LUA_INTERFACE_DEF (trie, create),
60 LUA_INTERFACE_DEF (trie, has_hyperscan),
61 {NULL, NULL}
62 };
63
64 static struct rspamd_multipattern *
lua_check_trie(lua_State * L,gint idx)65 lua_check_trie (lua_State * L, gint idx)
66 {
67 void *ud = rspamd_lua_check_udata (L, 1, "rspamd{trie}");
68
69 luaL_argcheck (L, ud != NULL, 1, "'trie' expected");
70 return ud ? *((struct rspamd_multipattern **)ud) : NULL;
71 }
72
73 static gint
lua_trie_destroy(lua_State * L)74 lua_trie_destroy (lua_State *L)
75 {
76 struct rspamd_multipattern *trie = lua_check_trie (L, 1);
77
78 if (trie) {
79 rspamd_multipattern_destroy (trie);
80 }
81
82 return 0;
83 }
84
85 /***
86 * function trie.has_hyperscan()
87 * Checks for hyperscan support
88 *
89 * @return {bool} true if hyperscan is supported
90 */
91 static gint
lua_trie_has_hyperscan(lua_State * L)92 lua_trie_has_hyperscan (lua_State *L)
93 {
94 lua_pushboolean (L, rspamd_multipattern_has_hyperscan ());
95 return 1;
96 }
97
98 /***
99 * function trie.create(patterns, [flags])
100 * Creates new trie data structure
101 * @param {table} array of string patterns
102 * @return {trie} new trie object
103 */
104 static gint
lua_trie_create(lua_State * L)105 lua_trie_create (lua_State *L)
106 {
107 struct rspamd_multipattern *trie, **ptrie;
108 gint npat = 0, flags = RSPAMD_MULTIPATTERN_ICASE|RSPAMD_MULTIPATTERN_GLOB;
109 GError *err = NULL;
110
111 if (lua_isnumber (L, 2)) {
112 flags = lua_tointeger (L, 2);
113 }
114
115 if (!lua_istable (L, 1)) {
116 return luaL_error (L, "lua trie expects array of patterns for now");
117 }
118 else {
119 lua_pushvalue (L, 1);
120 lua_pushnil (L);
121
122 while (lua_next (L, -2) != 0) {
123 if (lua_isstring (L, -1)) {
124 npat ++;
125 }
126
127 lua_pop (L, 1);
128 }
129
130 trie = rspamd_multipattern_create_sized (npat, flags);
131 lua_pushnil (L);
132
133 while (lua_next (L, -2) != 0) {
134 if (lua_isstring (L, -1)) {
135 const gchar *pat;
136 gsize patlen;
137
138 pat = lua_tolstring (L, -1, &patlen);
139 rspamd_multipattern_add_pattern_len (trie, pat, patlen, flags);
140 }
141
142 lua_pop (L, 1);
143 }
144
145 lua_pop (L, 1); /* table */
146
147 if (!rspamd_multipattern_compile (trie, &err)) {
148 msg_err ("cannot compile multipattern: %e", err);
149 g_error_free (err);
150 rspamd_multipattern_destroy (trie);
151 lua_pushnil (L);
152 }
153 else {
154 ptrie = lua_newuserdata (L, sizeof (void *));
155 rspamd_lua_setclass (L, "rspamd{trie}", -1);
156 *ptrie = trie;
157 }
158 }
159
160 return 1;
161 }
162
163 #define PUSH_TRIE_MATCH(L, start, end, report_start) do { \
164 if (report_start) { \
165 lua_createtable (L, 2, 0); \
166 lua_pushinteger (L, (start)); \
167 lua_rawseti (L, -2, 1); \
168 lua_pushinteger (L, (end)); \
169 lua_rawseti (L, -2, 2); \
170 } \
171 else { \
172 lua_pushinteger (L, (end)); \
173 } \
174 } while(0)
175
176 /* Normal callback type */
177 static gint
lua_trie_lua_cb_callback(struct rspamd_multipattern * mp,guint strnum,gint match_start,gint textpos,const gchar * text,gsize len,void * context)178 lua_trie_lua_cb_callback (struct rspamd_multipattern *mp,
179 guint strnum,
180 gint match_start,
181 gint textpos,
182 const gchar *text,
183 gsize len,
184 void *context)
185 {
186 lua_State *L = context;
187 gint ret;
188
189 gboolean report_start = lua_toboolean (L, -1);
190
191 /* Function */
192 lua_pushvalue (L, 3);
193 lua_pushinteger (L, strnum + 1);
194
195 PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
196
197 if (lua_pcall (L, 2, 1, 0) != 0) {
198 msg_info ("call to trie callback has failed: %s",
199 lua_tostring (L, -1));
200 lua_pop (L, 1);
201
202 return 1;
203 }
204
205 ret = lua_tonumber (L, -1);
206 lua_pop (L, 1);
207
208 return ret;
209 }
210
211 /* Table like callback, expect result table on top of the stack */
212 static gint
lua_trie_table_callback(struct rspamd_multipattern * mp,guint strnum,gint match_start,gint textpos,const gchar * text,gsize len,void * context)213 lua_trie_table_callback (struct rspamd_multipattern *mp,
214 guint strnum,
215 gint match_start,
216 gint textpos,
217 const gchar *text,
218 gsize len,
219 void *context)
220 {
221 lua_State *L = context;
222
223 gint report_start = lua_toboolean (L, -2);
224 /* Set table, indexed by pattern number */
225 lua_rawgeti (L, -1, strnum + 1);
226
227 if (lua_istable (L, -1)) {
228 /* Already have table, add offset */
229 gsize last = rspamd_lua_table_size (L, -1);
230 PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
231 lua_rawseti (L, -2, last + 1);
232 /* Remove table from the stack */
233 lua_pop (L, 1);
234 }
235 else {
236 /* Pop none */
237 lua_pop (L, 1);
238 /* New table */
239 lua_newtable (L);
240 PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
241 lua_rawseti (L, -2, 1);
242 lua_rawseti (L, -2, strnum + 1);
243 }
244
245 return 0;
246 }
247
248 /*
249 * We assume that callback argument is at pos 3 and icase is in position 4
250 */
251 static gint
lua_trie_search_str(lua_State * L,struct rspamd_multipattern * trie,const gchar * str,gsize len,rspamd_multipattern_cb_t cb)252 lua_trie_search_str (lua_State *L, struct rspamd_multipattern *trie,
253 const gchar *str, gsize len, rspamd_multipattern_cb_t cb)
254 {
255 gint ret;
256 guint nfound = 0;
257
258 if ((ret = rspamd_multipattern_lookup (trie, str, len,
259 cb, L, &nfound)) == 0) {
260 return nfound;
261 }
262
263 return ret;
264 }
265
266 /***
267 * @method trie:match(input, [cb][, report_start])
268 * Search for patterns in `input` invoking `cb` optionally ignoring case
269 * @param {table or string} input one or several (if `input` is an array) strings of input text
270 * @param {function} cb callback called on each pattern match in form `function (idx, pos)` where `idx` is a numeric index of pattern (starting from 1) and `pos` is a numeric offset where the pattern ends
271 * @param {boolean} report_start report both start and end offset when matching patterns
272 * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however). If `cb` is not defined then it returns a table of match positions indexed by pattern number
273 */
274 static gint
lua_trie_match(lua_State * L)275 lua_trie_match (lua_State *L)
276 {
277 LUA_TRACE_POINT;
278 struct rspamd_multipattern *trie = lua_check_trie (L, 1);
279 const gchar *text;
280 gsize len;
281 gboolean found = FALSE, report_start = FALSE;
282 struct rspamd_lua_text *t;
283 rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
284
285 gint old_top = lua_gettop (L);
286
287 if (trie) {
288 if (lua_type (L, 3) != LUA_TFUNCTION) {
289 if (lua_isboolean (L, 3)) {
290 report_start = lua_toboolean (L, 3);
291 }
292
293 lua_pushboolean (L, report_start);
294 /* Table like match */
295 lua_newtable (L);
296 cb = lua_trie_table_callback;
297 }
298 else {
299 if (lua_isboolean (L, 4)) {
300 report_start = lua_toboolean (L, 4);
301 }
302 lua_pushboolean (L, report_start);
303 }
304
305 if (lua_type (L, 2) == LUA_TTABLE) {
306 lua_pushvalue (L, 2);
307 lua_pushnil (L);
308
309 while (lua_next (L, -2) != 0) {
310 if (lua_isstring (L, -1)) {
311 text = lua_tolstring (L, -1, &len);
312
313 if (lua_trie_search_str (L, trie, text, len, cb)) {
314 found = TRUE;
315 }
316 }
317 else if (lua_isuserdata (L, -1)) {
318 t = lua_check_text (L, -1);
319
320 if (t) {
321 if (lua_trie_search_str (L, trie, t->start, t->len, cb)) {
322 found = TRUE;
323 }
324 }
325 }
326 lua_pop (L, 1);
327 }
328 }
329 else if (lua_type (L, 2) == LUA_TSTRING) {
330 text = lua_tolstring (L, 2, &len);
331
332 if (lua_trie_search_str (L, trie, text, len, cb)) {
333 found = TRUE;
334 }
335 }
336 else if (lua_type (L, 2) == LUA_TUSERDATA) {
337 t = lua_check_text (L, 2);
338
339 if (t && lua_trie_search_str (L, trie, t->start, t->len, cb)) {
340 found = TRUE;
341 }
342 }
343 }
344
345 if (lua_type (L, 3) == LUA_TFUNCTION) {
346 lua_settop (L, old_top);
347 lua_pushboolean (L, found);
348 }
349 else {
350 lua_remove (L, -2);
351 }
352
353 return 1;
354 }
355
356 /***
357 * @method trie:search_mime(task, cb)
358 * This is a helper mehthod to search pattern within text parts of a message in rspamd task
359 * @param {task} task object
360 * @param {function} cb callback called on each pattern match @see trie:match
361 * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
362 * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
363 */
364 static gint
lua_trie_search_mime(lua_State * L)365 lua_trie_search_mime (lua_State *L)
366 {
367 LUA_TRACE_POINT;
368 struct rspamd_multipattern *trie = lua_check_trie (L, 1);
369 struct rspamd_task *task = lua_check_task (L, 2);
370 struct rspamd_mime_text_part *part;
371 const gchar *text;
372 gsize len, i;
373 gboolean found = FALSE;
374 rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
375
376 if (trie && task) {
377 PTR_ARRAY_FOREACH (MESSAGE_FIELD (task, text_parts), i, part) {
378 if (!IS_TEXT_PART_EMPTY (part) && part->utf_content.len > 0) {
379 text = part->utf_content.begin;
380 len = part->utf_content.len;
381
382 if (lua_trie_search_str (L, trie, text, len, cb) != 0) {
383 found = TRUE;
384 }
385 }
386 }
387 }
388
389 lua_pushboolean (L, found);
390 return 1;
391 }
392
393 /***
394 * @method trie:search_rawmsg(task, cb[, caseless])
395 * This is a helper mehthod to search pattern within the whole undecoded content of rspamd task
396 * @param {task} task object
397 * @param {function} cb callback called on each pattern match @see trie:match
398 * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
399 * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
400 */
401 static gint
lua_trie_search_rawmsg(lua_State * L)402 lua_trie_search_rawmsg (lua_State *L)
403 {
404 LUA_TRACE_POINT;
405 struct rspamd_multipattern *trie = lua_check_trie (L, 1);
406 struct rspamd_task *task = lua_check_task (L, 2);
407 const gchar *text;
408 gsize len;
409 gboolean found = FALSE;
410
411 if (trie && task) {
412 text = task->msg.begin;
413 len = task->msg.len;
414
415 if (lua_trie_search_str (L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
416 found = TRUE;
417 }
418 }
419
420 lua_pushboolean (L, found);
421 return 1;
422 }
423
424 /***
425 * @method trie:search_rawbody(task, cb[, caseless])
426 * This is a helper mehthod to search pattern within the whole undecoded content of task's body (not including headers)
427 * @param {task} task object
428 * @param {function} cb callback called on each pattern match @see trie:match
429 * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
430 * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
431 */
432 static gint
lua_trie_search_rawbody(lua_State * L)433 lua_trie_search_rawbody (lua_State *L)
434 {
435 LUA_TRACE_POINT;
436 struct rspamd_multipattern *trie = lua_check_trie (L, 1);
437 struct rspamd_task *task = lua_check_task (L, 2);
438 const gchar *text;
439 gsize len;
440 gboolean found = FALSE;
441
442 if (trie && task) {
443 if (MESSAGE_FIELD (task, raw_headers_content).len > 0) {
444 text = task->msg.begin + MESSAGE_FIELD (task, raw_headers_content).len;
445 len = task->msg.len - MESSAGE_FIELD (task, raw_headers_content).len;
446 }
447 else {
448 /* Treat as raw message */
449 text = task->msg.begin;
450 len = task->msg.len;
451 }
452
453 if (lua_trie_search_str (L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
454 found = TRUE;
455 }
456 }
457
458 lua_pushboolean (L, found);
459 return 1;
460 }
461
462 static gint
lua_load_trie(lua_State * L)463 lua_load_trie (lua_State *L)
464 {
465 lua_newtable (L);
466
467 /* Flags */
468 lua_pushstring (L, "flags");
469 lua_newtable (L);
470
471 lua_pushinteger (L, RSPAMD_MULTIPATTERN_GLOB);
472 lua_setfield (L, -2, "glob");
473 lua_pushinteger (L, RSPAMD_MULTIPATTERN_RE);
474 lua_setfield (L, -2, "re");
475 lua_pushinteger (L, RSPAMD_MULTIPATTERN_ICASE);
476 lua_setfield (L, -2, "icase");
477 lua_pushinteger (L, RSPAMD_MULTIPATTERN_UTF8);
478 lua_setfield (L, -2, "utf8");
479 lua_pushinteger (L, RSPAMD_MULTIPATTERN_TLD);
480 lua_setfield (L, -2, "tld");
481 lua_pushinteger (L, RSPAMD_MULTIPATTERN_DOTALL);
482 lua_setfield (L, -2, "dot_all");
483 lua_pushinteger (L, RSPAMD_MULTIPATTERN_SINGLEMATCH);
484 lua_setfield (L, -2, "single_match");
485 lua_pushinteger (L, RSPAMD_MULTIPATTERN_NO_START);
486 lua_setfield (L, -2, "no_start");
487 lua_settable (L, -3);
488
489 /* Main content */
490 luaL_register (L, NULL, trielib_f);
491
492 return 1;
493 }
494
495 void
luaopen_trie(lua_State * L)496 luaopen_trie (lua_State * L)
497 {
498 rspamd_lua_new_class (L, "rspamd{trie}", trielib_m);
499 lua_pop (L, 1);
500 rspamd_lua_add_preload (L, "rspamd_trie", lua_load_trie);
501 }
502