1 /*-
2  * Copyright 2016 Vsevolod Stakhov
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "lua_common.h"
17 #include "message.h"
18 #include "libutil/multipattern.h"
19 
20 /***
21  * @module rspamd_trie
 * Rspamd trie module provides a data structure suitable for searching for many
 * patterns in arbitrary texts (or binary chunks). The algorithmic complexity of
 * this algorithm is at most O(n + m + z), where `n` is the length of the text, `m` is the length of the patterns and `z` is the number of patterns matched in the text.
25  *
26  * Here is a typical example of trie usage:
27  * @example
28 local rspamd_trie = require "rspamd_trie"
29 local patterns = {'aab', 'ab', 'bcd\0ef'}
30 
31 local trie = rspamd_trie.create(patterns)
32 
33 local function trie_callback(number, pos)
34 	print('Matched pattern number ' .. tostring(number) .. ' at pos: ' .. tostring(pos))
35 end
36 
37 trie:match('some big text', trie_callback)
38  */
39 
40 /* Suffix trie */
41 LUA_FUNCTION_DEF (trie, create);
42 LUA_FUNCTION_DEF (trie, has_hyperscan);
43 LUA_FUNCTION_DEF (trie, match);
44 LUA_FUNCTION_DEF (trie, search_mime);
45 LUA_FUNCTION_DEF (trie, search_rawmsg);
46 LUA_FUNCTION_DEF (trie, search_rawbody);
47 LUA_FUNCTION_DEF (trie, destroy);
48 
49 static const struct luaL_reg trielib_m[] = {
50 	LUA_INTERFACE_DEF (trie, match),
51 	LUA_INTERFACE_DEF (trie, search_mime),
52 	LUA_INTERFACE_DEF (trie, search_rawmsg),
53 	LUA_INTERFACE_DEF (trie, search_rawbody),
54 	{"__tostring", rspamd_lua_class_tostring},
55 	{"__gc", lua_trie_destroy},
56 	{NULL, NULL}
57 };
58 static const struct luaL_reg trielib_f[] = {
59 	LUA_INTERFACE_DEF (trie, create),
60 	LUA_INTERFACE_DEF (trie, has_hyperscan),
61 	{NULL, NULL}
62 };
63 
64 static struct rspamd_multipattern *
lua_check_trie(lua_State * L,gint idx)65 lua_check_trie (lua_State * L, gint idx)
66 {
67 	void *ud = rspamd_lua_check_udata (L, 1, "rspamd{trie}");
68 
69 	luaL_argcheck (L, ud != NULL, 1, "'trie' expected");
70 	return ud ? *((struct rspamd_multipattern **)ud) : NULL;
71 }
72 
73 static gint
lua_trie_destroy(lua_State * L)74 lua_trie_destroy (lua_State *L)
75 {
76 	struct rspamd_multipattern *trie = lua_check_trie (L, 1);
77 
78 	if (trie) {
79 		rspamd_multipattern_destroy (trie);
80 	}
81 
82 	return 0;
83 }
84 
85 /***
86  * function trie.has_hyperscan()
87  * Checks for hyperscan support
88  *
89  * @return {bool} true if hyperscan is supported
90  */
91 static gint
lua_trie_has_hyperscan(lua_State * L)92 lua_trie_has_hyperscan (lua_State *L)
93 {
94 	lua_pushboolean (L, rspamd_multipattern_has_hyperscan ());
95 	return 1;
96 }
97 
98 /***
99  * function trie.create(patterns, [flags])
100  * Creates new trie data structure
101  * @param {table} array of string patterns
102  * @return {trie} new trie object
103  */
104 static gint
lua_trie_create(lua_State * L)105 lua_trie_create (lua_State *L)
106 {
107 	struct rspamd_multipattern *trie, **ptrie;
108 	gint npat = 0, flags = RSPAMD_MULTIPATTERN_ICASE|RSPAMD_MULTIPATTERN_GLOB;
109 	GError *err = NULL;
110 
111 	if (lua_isnumber (L, 2)) {
112 		flags = lua_tointeger (L, 2);
113 	}
114 
115 	if (!lua_istable (L, 1)) {
116 		return luaL_error (L, "lua trie expects array of patterns for now");
117 	}
118 	else {
119 		lua_pushvalue (L, 1);
120 		lua_pushnil (L);
121 
122 		while (lua_next (L, -2) != 0) {
123 			if (lua_isstring (L, -1)) {
124 				npat ++;
125 			}
126 
127 			lua_pop (L, 1);
128 		}
129 
130 		trie = rspamd_multipattern_create_sized (npat, flags);
131 		lua_pushnil (L);
132 
133 		while (lua_next (L, -2) != 0) {
134 			if (lua_isstring (L, -1)) {
135 				const gchar *pat;
136 				gsize patlen;
137 
138 				pat = lua_tolstring (L, -1, &patlen);
139 				rspamd_multipattern_add_pattern_len (trie, pat, patlen, flags);
140 			}
141 
142 			lua_pop (L, 1);
143 		}
144 
145 		lua_pop (L, 1); /* table */
146 
147 		if (!rspamd_multipattern_compile (trie, &err)) {
148 			msg_err ("cannot compile multipattern: %e", err);
149 			g_error_free (err);
150 			rspamd_multipattern_destroy (trie);
151 			lua_pushnil (L);
152 		}
153 		else {
154 			ptrie = lua_newuserdata (L, sizeof (void *));
155 			rspamd_lua_setclass (L, "rspamd{trie}", -1);
156 			*ptrie = trie;
157 		}
158 	}
159 
160 	return 1;
161 }
162 
/*
 * Pushes the position of a match onto the Lua stack of `L`:
 *  - when `report_start` is false, the integer `end` alone;
 *  - otherwise a two-element array `{start, end}`.
 * Exactly one value is left on the stack in both cases.
 */
#define PUSH_TRIE_MATCH(L, start, end, report_start) do { \
	if (!(report_start)) { \
		lua_pushinteger (L, (end)); \
	} \
	else { \
		lua_createtable (L, 2, 0); \
		lua_pushinteger (L, (start)); \
		lua_rawseti (L, -2, 1); \
		lua_pushinteger (L, (end)); \
		lua_rawseti (L, -2, 2); \
	} \
} while(0)
175 
176 /* Normal callback type */
177 static gint
lua_trie_lua_cb_callback(struct rspamd_multipattern * mp,guint strnum,gint match_start,gint textpos,const gchar * text,gsize len,void * context)178 lua_trie_lua_cb_callback (struct rspamd_multipattern *mp,
179 						  guint strnum,
180 						  gint match_start,
181 						  gint textpos,
182 						  const gchar *text,
183 						  gsize len,
184 						  void *context)
185 {
186 	lua_State *L = context;
187 	gint ret;
188 
189 	gboolean report_start = lua_toboolean (L, -1);
190 
191 	/* Function */
192 	lua_pushvalue (L, 3);
193 	lua_pushinteger (L, strnum + 1);
194 
195 	PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
196 
197 	if (lua_pcall (L, 2, 1, 0) != 0) {
198 		msg_info ("call to trie callback has failed: %s",
199 			lua_tostring (L, -1));
200 		lua_pop (L, 1);
201 
202 		return 1;
203 	}
204 
205 	ret = lua_tonumber (L, -1);
206 	lua_pop (L, 1);
207 
208 	return ret;
209 }
210 
211 /* Table like callback, expect result table on top of the stack */
212 static gint
lua_trie_table_callback(struct rspamd_multipattern * mp,guint strnum,gint match_start,gint textpos,const gchar * text,gsize len,void * context)213 lua_trie_table_callback (struct rspamd_multipattern *mp,
214 				   guint strnum,
215 				   gint match_start,
216 				   gint textpos,
217 				   const gchar *text,
218 				   gsize len,
219 				   void *context)
220 {
221 	lua_State *L = context;
222 
223 	gint report_start = lua_toboolean (L, -2);
224 	/* Set table, indexed by pattern number */
225 	lua_rawgeti (L, -1, strnum + 1);
226 
227 	if (lua_istable (L, -1)) {
228 		/* Already have table, add offset */
229 		gsize last = rspamd_lua_table_size (L, -1);
230 		PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
231 		lua_rawseti (L, -2, last + 1);
232 		/* Remove table from the stack */
233 		lua_pop (L, 1);
234 	}
235 	else {
236 		/* Pop none */
237 		lua_pop (L, 1);
238 		/* New table */
239 		lua_newtable (L);
240 		PUSH_TRIE_MATCH (L, match_start, textpos, report_start);
241 		lua_rawseti (L, -2, 1);
242 		lua_rawseti (L, -2, strnum + 1);
243 	}
244 
245 	return 0;
246 }
247 
248 /*
249  * We assume that callback argument is at pos 3 and icase is in position 4
250  */
251 static gint
lua_trie_search_str(lua_State * L,struct rspamd_multipattern * trie,const gchar * str,gsize len,rspamd_multipattern_cb_t cb)252 lua_trie_search_str (lua_State *L, struct rspamd_multipattern *trie,
253 		const gchar *str, gsize len, rspamd_multipattern_cb_t cb)
254 {
255 	gint ret;
256 	guint nfound = 0;
257 
258 	if ((ret = rspamd_multipattern_lookup (trie, str, len,
259 			cb, L, &nfound)) == 0) {
260 		return nfound;
261 	}
262 
263 	return ret;
264 }
265 
266 /***
267  * @method trie:match(input, [cb][, report_start])
268  * Search for patterns in `input` invoking `cb` optionally ignoring case
269  * @param {table or string} input one or several (if `input` is an array) strings of input text
270  * @param {function} cb callback called on each pattern match in form `function (idx, pos)` where `idx` is a numeric index of pattern (starting from 1) and `pos` is a numeric offset where the pattern ends
271  * @param {boolean} report_start report both start and end offset when matching patterns
272  * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however). If `cb` is not defined then it returns a table of match positions indexed by pattern number
273  */
274 static gint
lua_trie_match(lua_State * L)275 lua_trie_match (lua_State *L)
276 {
277 	LUA_TRACE_POINT;
278 	struct rspamd_multipattern *trie = lua_check_trie (L, 1);
279 	const gchar *text;
280 	gsize len;
281 	gboolean found = FALSE, report_start = FALSE;
282 	struct rspamd_lua_text *t;
283 	rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
284 
285 	gint old_top = lua_gettop (L);
286 
287 	if (trie) {
288 		if (lua_type (L, 3) != LUA_TFUNCTION) {
289 			if (lua_isboolean (L, 3)) {
290 				report_start = lua_toboolean (L, 3);
291 			}
292 
293 			lua_pushboolean (L, report_start);
294 			/* Table like match */
295 			lua_newtable (L);
296 			cb = lua_trie_table_callback;
297 		}
298 		else {
299 			if (lua_isboolean (L, 4)) {
300 				report_start = lua_toboolean (L, 4);
301 			}
302 			lua_pushboolean (L, report_start);
303 		}
304 
305 		if (lua_type (L, 2) == LUA_TTABLE) {
306 			lua_pushvalue (L, 2);
307 			lua_pushnil (L);
308 
309 			while (lua_next (L, -2) != 0) {
310 				if (lua_isstring (L, -1)) {
311 					text = lua_tolstring (L, -1, &len);
312 
313 					if (lua_trie_search_str (L, trie, text, len, cb)) {
314 						found = TRUE;
315 					}
316 				}
317 				else if (lua_isuserdata (L, -1)) {
318 					t = lua_check_text (L, -1);
319 
320 					if (t) {
321 						if (lua_trie_search_str (L, trie, t->start, t->len, cb)) {
322 							found = TRUE;
323 						}
324 					}
325 				}
326 				lua_pop (L, 1);
327 			}
328 		}
329 		else if (lua_type (L, 2) == LUA_TSTRING) {
330 			text = lua_tolstring (L, 2, &len);
331 
332 			if (lua_trie_search_str (L, trie, text, len, cb)) {
333 				found = TRUE;
334 			}
335 		}
336 		else if (lua_type (L, 2) == LUA_TUSERDATA) {
337 			t = lua_check_text (L, 2);
338 
339 			if (t && lua_trie_search_str (L, trie, t->start, t->len, cb)) {
340 				found = TRUE;
341 			}
342 		}
343 	}
344 
345 	if (lua_type (L, 3) == LUA_TFUNCTION) {
346 		lua_settop (L, old_top);
347 		lua_pushboolean (L, found);
348 	}
349 	else {
350 		lua_remove (L, -2);
351 	}
352 
353 	return 1;
354 }
355 
356 /***
357  * @method trie:search_mime(task, cb)
358  * This is a helper mehthod to search pattern within text parts of a message in rspamd task
359  * @param {task} task object
360  * @param {function} cb callback called on each pattern match @see trie:match
361  * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
362  * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
363  */
364 static gint
lua_trie_search_mime(lua_State * L)365 lua_trie_search_mime (lua_State *L)
366 {
367 	LUA_TRACE_POINT;
368 	struct rspamd_multipattern *trie = lua_check_trie (L, 1);
369 	struct rspamd_task *task = lua_check_task (L, 2);
370 	struct rspamd_mime_text_part *part;
371 	const gchar *text;
372 	gsize len, i;
373 	gboolean found = FALSE;
374 	rspamd_multipattern_cb_t cb = lua_trie_lua_cb_callback;
375 
376 	if (trie && task) {
377 		PTR_ARRAY_FOREACH (MESSAGE_FIELD (task, text_parts), i, part) {
378 			if (!IS_TEXT_PART_EMPTY (part) && part->utf_content.len > 0) {
379 				text = part->utf_content.begin;
380 				len = part->utf_content.len;
381 
382 				if (lua_trie_search_str (L, trie, text, len, cb) != 0) {
383 					found = TRUE;
384 				}
385 			}
386 		}
387 	}
388 
389 	lua_pushboolean (L, found);
390 	return 1;
391 }
392 
393 /***
394  * @method trie:search_rawmsg(task, cb[, caseless])
395  * This is a helper mehthod to search pattern within the whole undecoded content of rspamd task
396  * @param {task} task object
397  * @param {function} cb callback called on each pattern match @see trie:match
398  * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
399  * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
400  */
401 static gint
lua_trie_search_rawmsg(lua_State * L)402 lua_trie_search_rawmsg (lua_State *L)
403 {
404 	LUA_TRACE_POINT;
405 	struct rspamd_multipattern *trie = lua_check_trie (L, 1);
406 	struct rspamd_task *task = lua_check_task (L, 2);
407 	const gchar *text;
408 	gsize len;
409 	gboolean found = FALSE;
410 
411 	if (trie && task) {
412 		text = task->msg.begin;
413 		len = task->msg.len;
414 
415 		if (lua_trie_search_str (L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
416 			found = TRUE;
417 		}
418 	}
419 
420 	lua_pushboolean (L, found);
421 	return 1;
422 }
423 
424 /***
425  * @method trie:search_rawbody(task, cb[, caseless])
426  * This is a helper mehthod to search pattern within the whole undecoded content of task's body (not including headers)
427  * @param {task} task object
428  * @param {function} cb callback called on each pattern match @see trie:match
429  * @param {boolean} caseless if `true` then match ignores symbols case (ASCII only)
430  * @return {boolean} `true` if any pattern has been found (`cb` might be called multiple times however)
431  */
432 static gint
lua_trie_search_rawbody(lua_State * L)433 lua_trie_search_rawbody (lua_State *L)
434 {
435 	LUA_TRACE_POINT;
436 	struct rspamd_multipattern *trie = lua_check_trie (L, 1);
437 	struct rspamd_task *task = lua_check_task (L, 2);
438 	const gchar *text;
439 	gsize len;
440 	gboolean found = FALSE;
441 
442 	if (trie && task) {
443 		if (MESSAGE_FIELD (task, raw_headers_content).len > 0) {
444 			text = task->msg.begin + MESSAGE_FIELD (task, raw_headers_content).len;
445 			len = task->msg.len - MESSAGE_FIELD (task, raw_headers_content).len;
446 		}
447 		else {
448 			/* Treat as raw message */
449 			text = task->msg.begin;
450 			len = task->msg.len;
451 		}
452 
453 		if (lua_trie_search_str (L, trie, text, len, lua_trie_lua_cb_callback) != 0) {
454 			found = TRUE;
455 		}
456 	}
457 
458 	lua_pushboolean (L, found);
459 	return 1;
460 }
461 
462 static gint
lua_load_trie(lua_State * L)463 lua_load_trie (lua_State *L)
464 {
465 	lua_newtable (L);
466 
467 	/* Flags */
468 	lua_pushstring (L, "flags");
469 	lua_newtable (L);
470 
471 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_GLOB);
472 	lua_setfield (L, -2, "glob");
473 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_RE);
474 	lua_setfield (L, -2, "re");
475 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_ICASE);
476 	lua_setfield (L, -2, "icase");
477 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_UTF8);
478 	lua_setfield (L, -2, "utf8");
479 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_TLD);
480 	lua_setfield (L, -2, "tld");
481 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_DOTALL);
482 	lua_setfield (L, -2, "dot_all");
483 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_SINGLEMATCH);
484 	lua_setfield (L, -2, "single_match");
485 	lua_pushinteger (L, RSPAMD_MULTIPATTERN_NO_START);
486 	lua_setfield (L, -2, "no_start");
487 	lua_settable (L, -3);
488 
489 	/* Main content */
490 	luaL_register (L, NULL, trielib_f);
491 
492 	return 1;
493 }
494 
495 void
luaopen_trie(lua_State * L)496 luaopen_trie (lua_State * L)
497 {
498 	rspamd_lua_new_class (L, "rspamd{trie}", trielib_m);
499 	lua_pop (L, 1);
500 	rspamd_lua_add_preload (L, "rspamd_trie", lua_load_trie);
501 }
502