1 /*-
2  * Copyright 2019 Vsevolod Stakhov
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *   http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "lua_common.h"
18 #include "lua_tensor.h"
19 #include "contrib/kann/kann.h"
20 
21 /***
22  * @module rspamd_kann
 * `rspamd_kann` is a Lua interface to the kann library
24  */
25 
/* Lua class names of the userdata objects exported by this module */
#define KANN_NODE_CLASS "rspamd{kann_node}"
#define KANN_NETWORK_CLASS "rspamd{kann}"

/*
 * Declaration helpers, one pair per family of bindings:
 *  - KANN_<FAMILY>_DEF(name) expands to the header
 *    `static int lua_kann_<family>_<name> (lua_State *L)`;
 *  - KANN_<FAMILY>_INTERFACE(name) expands to the luaL_reg entry
 *    `{"name", lua_kann_<family>_<name>}`.
 */
#define KANN_LAYER_DEF(name) static int lua_kann_layer_##name (lua_State *L)
#define KANN_LAYER_INTERFACE(name) {#name, lua_kann_layer_##name}

#define KANN_TRANSFORM_DEF(name) static int lua_kann_transform_##name (lua_State *L)
#define KANN_TRANSFORM_INTERFACE(name) {#name, lua_kann_transform_##name}

#define KANN_LOSS_DEF(name) static int lua_kann_loss_##name (lua_State *L)
#define KANN_LOSS_INTERFACE(name) {#name, lua_kann_loss_##name}

#define KANN_NEW_DEF(name) static int lua_kann_new_##name (lua_State *L)
#define KANN_NEW_INTERFACE(name) {#name, lua_kann_new_##name}
41 
42 
43 /*
44  * Forwarded declarations
45  */
46 static kad_node_t *lua_check_kann_node (lua_State *L, int pos);
47 
48 /* Layers */
49 KANN_LAYER_DEF(input);
50 KANN_LAYER_DEF(dense);
51 KANN_LAYER_DEF(layernorm);
52 KANN_LAYER_DEF(rnn);
53 KANN_LAYER_DEF(lstm);
54 KANN_LAYER_DEF(gru);
55 KANN_LAYER_DEF(conv2d);
56 KANN_LAYER_DEF(conv1d);
57 KANN_LAYER_DEF(cost);
58 
59 static luaL_reg rspamd_kann_layers_f[] = {
60 		KANN_LAYER_INTERFACE(input),
61 		KANN_LAYER_INTERFACE(dense),
62 		KANN_LAYER_INTERFACE(layernorm),
63 		KANN_LAYER_INTERFACE(rnn),
64 		KANN_LAYER_INTERFACE(lstm),
65 		KANN_LAYER_INTERFACE(gru),
66 		KANN_LAYER_INTERFACE(conv2d),
67 		KANN_LAYER_INTERFACE(conv1d),
68 		KANN_LAYER_INTERFACE(cost),
69 		{NULL, NULL},
70 };
71 
72 /* Transition and composition functions */
73 
74 /* General transform */
75 KANN_TRANSFORM_DEF (add);
76 KANN_TRANSFORM_DEF (sub);
77 KANN_TRANSFORM_DEF (mul);
78 KANN_TRANSFORM_DEF (cmul);
79 KANN_TRANSFORM_DEF (matmul);
80 
81 KANN_TRANSFORM_DEF (square);
82 KANN_TRANSFORM_DEF (sigm);
83 KANN_TRANSFORM_DEF (tanh);
84 KANN_TRANSFORM_DEF (relu);
85 KANN_TRANSFORM_DEF (softmax);
86 KANN_TRANSFORM_DEF (1minus);
87 KANN_TRANSFORM_DEF (exp);
88 KANN_TRANSFORM_DEF (log);
89 KANN_TRANSFORM_DEF (sin);
90 static luaL_reg rspamd_kann_transform_f[] = {
91 		KANN_TRANSFORM_INTERFACE (add),
92 		KANN_TRANSFORM_INTERFACE (sub),
93 		KANN_TRANSFORM_INTERFACE (mul),
94 		KANN_TRANSFORM_INTERFACE (cmul),
95 		KANN_TRANSFORM_INTERFACE (matmul),
96 
97 		KANN_TRANSFORM_INTERFACE (square),
98 		KANN_TRANSFORM_INTERFACE (sigm),
99 		KANN_TRANSFORM_INTERFACE (tanh),
100 		KANN_TRANSFORM_INTERFACE (relu),
101 		KANN_TRANSFORM_INTERFACE (softmax),
102 		KANN_TRANSFORM_INTERFACE (1minus),
103 		KANN_TRANSFORM_INTERFACE (exp),
104 		KANN_TRANSFORM_INTERFACE (log),
105 		KANN_TRANSFORM_INTERFACE (sin),
106 		{NULL, NULL},
107 };
108 
109 /* Loss functions */
110 KANN_LOSS_DEF (mse);
111 KANN_LOSS_DEF (ce_multi);
112 KANN_LOSS_DEF (ce_bin);
113 KANN_LOSS_DEF (ce_bin_neg);
114 KANN_LOSS_DEF (ce_multi_weighted);
115 static luaL_reg rspamd_kann_loss_f[] = {
116 		KANN_LOSS_INTERFACE (mse),
117 		KANN_LOSS_INTERFACE (ce_multi),
118 		KANN_LOSS_INTERFACE (ce_bin),
119 		KANN_LOSS_INTERFACE (ce_bin_neg),
120 		KANN_LOSS_INTERFACE (ce_multi_weighted),
121 		{NULL, NULL},
122 };
123 
124 /* Creation functions */
125 KANN_NEW_DEF (leaf);
126 KANN_NEW_DEF (scalar);
127 KANN_NEW_DEF (weight);
128 KANN_NEW_DEF (bias);
129 KANN_NEW_DEF (weight_conv2d);
130 KANN_NEW_DEF (weight_conv1d);
131 KANN_NEW_DEF (kann);
132 
133 static luaL_reg rspamd_kann_new_f[] = {
134 		KANN_NEW_INTERFACE (leaf),
135 		KANN_NEW_INTERFACE (scalar),
136 		KANN_NEW_INTERFACE (weight),
137 		KANN_NEW_INTERFACE (bias),
138 		KANN_NEW_INTERFACE (weight_conv2d),
139 		KANN_NEW_INTERFACE (weight_conv1d),
140 		KANN_NEW_INTERFACE (kann),
141 		{NULL, NULL},
142 };
143 
144 LUA_FUNCTION_DEF (kann, load);
145 LUA_FUNCTION_DEF (kann, destroy);
146 LUA_FUNCTION_DEF (kann, save);
147 LUA_FUNCTION_DEF (kann, train1);
148 LUA_FUNCTION_DEF (kann, apply1);
149 
150 static luaL_reg rspamd_kann_m[] = {
151 		LUA_INTERFACE_DEF (kann, save),
152 		LUA_INTERFACE_DEF (kann, train1),
153 		LUA_INTERFACE_DEF (kann, apply1),
154 		{"__gc", lua_kann_destroy},
155 		{NULL, NULL},
156 };
157 
158 static int
rspamd_kann_table_to_flags(lua_State * L,int table_pos)159 rspamd_kann_table_to_flags (lua_State *L, int table_pos)
160 {
161 	int result = 0;
162 
163 	lua_pushvalue (L, table_pos);
164 
165 	for (lua_pushnil (L); lua_next (L, -2); lua_pop (L, 1)) {
166 		int fl = lua_tointeger (L, -1);
167 
168 		result |= fl;
169 	}
170 
171 	lua_pop (L, 1);
172 
173 	return result;
174 }
175 
176 static gint
lua_load_kann(lua_State * L)177 lua_load_kann (lua_State * L)
178 {
179 	lua_newtable (L);
180 
181 	/* Flags */
182 	lua_pushstring (L, "flag");
183 	lua_newtable (L);
184 	lua_pushinteger (L, KANN_F_IN);
185 	lua_setfield (L, -2, "in");
186 	lua_pushinteger (L, KANN_F_COST);
187 	lua_setfield (L, -2, "cost");
188 	lua_pushinteger (L, KANN_F_OUT);
189 	lua_setfield (L, -2, "out");
190 	lua_pushinteger (L, KANN_F_TRUTH);
191 	lua_setfield (L, -2, "truth");
192 	lua_settable (L, -3);
193 
194 	/* Cost type */
195 	lua_pushstring (L, "cost");
196 	lua_newtable (L);
197 	/* binary cross-entropy cost, used with sigmoid */
198 	lua_pushinteger (L, KANN_C_CEB);
199 	lua_setfield (L, -2, "ceb");
200 	/* multi-class cross-entropy cost, used with softmax */
201 	lua_pushinteger (L, KANN_C_CEM);
202 	lua_setfield (L, -2, "cem");
203 	/* binary cross-entropy-like cost, used with tanh */
204 	lua_pushinteger (L, KANN_C_CEB_NEG);
205 	lua_setfield (L, -2, "ceb_neg");
206 	lua_pushinteger (L, KANN_C_MSE);
207 	lua_setfield (L, -2, "mse");
208 	lua_settable (L, -3);
209 
210 	/* RNN flag */
211 	lua_pushstring (L, "rnn");
212 	lua_newtable (L);
213 	/* apply layer normalization */
214 	lua_pushinteger (L, KANN_RNN_NORM);
215 	lua_setfield (L, -2, "norm");
216 	/* take the initial hidden values as variables */
217 	lua_pushinteger (L, KANN_RNN_VAR_H0);
218 	lua_setfield (L, -2, "var_h0");
219 	lua_settable (L, -3);
220 
221 	/* Layers */
222 	lua_pushstring (L, "layer");
223 	lua_newtable (L);
224 	luaL_register (L, NULL, rspamd_kann_layers_f);
225 	lua_settable (L, -3);
226 
227 	/* Transforms */
228 	lua_pushstring (L, "transform");
229 	lua_newtable (L);
230 	luaL_register (L, NULL, rspamd_kann_transform_f);
231 	lua_settable (L, -3);
232 
233 	/* Cost */
234 	lua_pushstring (L, "loss");
235 	lua_newtable (L);
236 	luaL_register (L, NULL, rspamd_kann_loss_f);
237 	lua_settable (L, -3);
238 
239 	/* Create functions */
240 	lua_pushstring (L, "new");
241 	lua_newtable (L);
242 	luaL_register (L, NULL, rspamd_kann_new_f);
243 	lua_settable (L, -3);
244 
245 	/* Load ann from memory or file */
246 	lua_pushstring (L, "load");
247 	lua_pushcfunction (L, lua_kann_load);
248 	lua_settable (L, -3);
249 
250 	return 1;
251 }
252 
253 static kad_node_t *
lua_check_kann_node(lua_State * L,int pos)254 lua_check_kann_node (lua_State *L, int pos)
255 {
256 	void *ud = rspamd_lua_check_udata (L, pos, KANN_NODE_CLASS);
257 	luaL_argcheck (L, ud != NULL, pos, "'kann_node' expected");
258 	return ud ? *((kad_node_t **)ud) : NULL;
259 }
260 
261 static kann_t *
lua_check_kann(lua_State * L,int pos)262 lua_check_kann (lua_State *L, int pos)
263 {
264 	void *ud = rspamd_lua_check_udata (L, pos, KANN_NETWORK_CLASS);
265 	luaL_argcheck (L, ud != NULL, pos, "'kann' expected");
266 	return ud ? *((kann_t **)ud) : NULL;
267 }
268 
/*
 * Registers the kann metatables and installs lua_load_kann as the preload
 * loader for "rspamd_kann". The Lua stack is emptied on return.
 */
void
luaopen_kann (lua_State *L)
{
	/* kann_node objects: no methods yet (TODO: add methods) */
	rspamd_lua_new_class (L, KANN_NODE_CLASS, NULL);
	lua_pop (L, 1); /* metatable is not needed here */

	/* kann network objects: save/train1/apply1 and __gc */
	rspamd_lua_new_class (L, KANN_NETWORK_CLASS, rspamd_kann_m);
	lua_pop (L, 1); /* metatable is not needed here */

	rspamd_lua_add_preload (L, "rspamd_kann", lua_load_kann);
	lua_settop (L, 0);
}
279 
/* Layers implementation */

/*
 * The helpers below expect a `lua_State *L` in scope.
 */

/* Wraps kad node `n` into a new KANN_NODE_CLASS userdata on top of the stack */
#define PUSH_KAD_NODE(n) do { \
	kad_node_t **pt; \
	pt = lua_newuserdata (L, sizeof (kad_node_t *)); \
	*pt = (n); \
	rspamd_lua_setclass (L, KANN_NODE_CLASS, -1); \
} while (0)

/* Wraps network `n` into a new KANN_NETWORK_CLASS userdata on top of the stack */
#define PUSH_KAN_NETWORK(n) do { \
	kann_t **pn; \
	pn = lua_newuserdata (L, sizeof (kann_t *)); \
	*pn = (n); \
	rspamd_lua_setclass (L, KANN_NETWORK_CLASS, -1); \
} while (0)

/*
 * ORs optional extra flags found at stack index `pos` into node `n`.
 * The flags may be an integer or a table of integers; other types are ignored.
 * A NULL node is left untouched rather than dereferenced.
 * NOTE(review): some kad/kann constructors appear to return NULL on
 * incompatible input shapes. A NULL node pushed afterwards is rejected as
 * "invalid arguments" by the next binding that receives it.
 */
#define PROCESS_KAD_FLAGS(n, pos) do { \
	kad_node_t *flag_node_ = (n); \
	int extra_flags_ = 0; \
	if (lua_type (L, (pos)) == LUA_TTABLE) { \
		extra_flags_ = rspamd_kann_table_to_flags (L, (pos)); \
	} \
	else if (lua_type (L, (pos)) == LUA_TNUMBER) { \
		extra_flags_ = lua_tointeger (L, (pos)); \
	} \
	if (flag_node_ != NULL) { \
		flag_node_->ext_flag |= extra_flags_; \
	} \
} while (0)
301 
302 /***
303  * @function kann.layer.input(ninputs[, flags])
304  * Creates an input layer for ANN
305  * @param {int} ninputs number of inputs
306  * @param {table|int} flags optional flags
307  * @return {kann_node} kann node object (should be used to combine ANN)
308 */
309 static int
lua_kann_layer_input(lua_State * L)310 lua_kann_layer_input (lua_State *L)
311 {
312 	gint nnodes = luaL_checkinteger (L, 1);
313 
314 	if (nnodes > 0) {
315 		kad_node_t *t;
316 
317 		t = kann_layer_input (nnodes);
318 
319 		PROCESS_KAD_FLAGS (t, 2);
320 		PUSH_KAD_NODE (t);
321 	}
322 	else {
323 		return luaL_error (L, "invalid arguments, nnodes required");
324 	}
325 
326 	return 1;
327 }
328 
329 /***
330  * @function kann.layer.dense(in, ninputs[, flags])
331  * Creates a dense layer (e.g. for hidden layer)
332  * @param {kann_node} in kann node
333  * @param {int} ninputs number of dense nodes
334  * @param {table|int} flags optional flags
335  * @return {kann_node} kann node object (should be used to combine ANN)
336 */
337 static int
lua_kann_layer_dense(lua_State * L)338 lua_kann_layer_dense (lua_State *L)
339 {
340 	kad_node_t *in = lua_check_kann_node (L, 1);
341 	gint nnodes = luaL_checkinteger (L, 2);
342 
343 	if (in != NULL && nnodes > 0) {
344 		kad_node_t *t;
345 
346 		t = kann_layer_dense (in, nnodes);
347 
348 		PROCESS_KAD_FLAGS (t, 3);
349 		PUSH_KAD_NODE (t);
350 	}
351 	else {
352 		return luaL_error (L, "invalid arguments, input + nnodes required");
353 	}
354 
355 	return 1;
356 }
357 
358 /***
359  * @function kann.layer.dropout(in, ratio[, flags])
360  * Creates a dropout layer
361  * @param {kann_node} in kann node
362  * @param {float} ratio drop ratio
363  * @param {table|int} flags optional flags
364  * @return {kann_node} kann node object (should be used to combine ANN)
365 */
366 static int
lua_kann_layer_layerdropout(lua_State * L)367 lua_kann_layer_layerdropout (lua_State *L)
368 {
369 	kad_node_t *in = lua_check_kann_node (L, 1);
370 	double r = luaL_checknumber (L, 2);
371 
372 	if (in != NULL) {
373 		kad_node_t *t;
374 
375 		t = kann_layer_dropout (in, r);
376 
377 		PROCESS_KAD_FLAGS (t, 3);
378 		PUSH_KAD_NODE (t);
379 	}
380 	else {
381 		return luaL_error (L, "invalid arguments, input + rate required");
382 	}
383 
384 	return 1;
385 }
386 
387 /***
388  * @function kann.layer.dropout(in [, flags])
389  * Creates a normalisation layer
390  * @param {kann_node} in kann node
391  * @param {table|int} flags optional flags
392  * @return {kann_node} kann node object (should be used to combine ANN)
393 */
394 static int
lua_kann_layer_layernorm(lua_State * L)395 lua_kann_layer_layernorm (lua_State *L)
396 {
397 	kad_node_t *in = lua_check_kann_node (L, 1);
398 
399 	if (in != NULL) {
400 		kad_node_t *t;
401 
402 		t = kann_layer_layernorm (in);
403 
404 		PROCESS_KAD_FLAGS (t, 2);
405 		PUSH_KAD_NODE (t);
406 	}
407 	else {
408 		return luaL_error (L, "invalid arguments, input required");
409 	}
410 
411 	return 1;
412 }
413 
414 /***
415  * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
416  * Creates a recursive NN layer
417  * @param {kann_node} in kann node
418  * @param {int} nnodes number of cells
419  * @param {int} rnnflags rnn flags
420  * @param {table|int} flags optional flags
421  * @return {kann_node} kann node object (should be used to combine ANN)
422 */
423 static int
lua_kann_layer_rnn(lua_State * L)424 lua_kann_layer_rnn (lua_State *L)
425 {
426 	kad_node_t *in = lua_check_kann_node (L, 1);
427 	gint nnodes = luaL_checkinteger (L, 2);
428 	gint rnnflags = 0;
429 
430 	if (in != NULL && nnodes > 0) {
431 		kad_node_t *t;
432 
433 		if (lua_type (L, 3) == LUA_TNUMBER) {
434 			rnnflags = lua_tointeger (L, 3);
435 		}
436 
437 		t = kann_layer_rnn (in, nnodes, rnnflags);
438 
439 		PROCESS_KAD_FLAGS (t, 4);
440 		PUSH_KAD_NODE (t);
441 	}
442 	else {
443 		return luaL_error (L, "invalid arguments, input + nnodes required");
444 	}
445 
446 	return 1;
447 }
448 
449 /***
450  * @function kann.layer.lstm(in, nnodes[, rnn_flags, [, flags]])
451  * Creates a recursive NN layer using LSTM cells
452  * @param {kann_node} in kann node
453  * @param {int} nnodes number of cells
454  * @param {int} rnnflags rnn flags
455  * @param {table|int} flags optional flags
456  * @return {kann_node} kann node object (should be used to combine ANN)
457 */
458 static int
lua_kann_layer_lstm(lua_State * L)459 lua_kann_layer_lstm (lua_State *L)
460 {
461 	kad_node_t *in = lua_check_kann_node (L, 1);
462 	gint nnodes = luaL_checkinteger (L, 2);
463 	gint rnnflags = 0;
464 
465 	if (in != NULL && nnodes > 0) {
466 		kad_node_t *t;
467 
468 		if (lua_type (L, 3) == LUA_TNUMBER) {
469 			rnnflags = lua_tointeger (L, 3);
470 		}
471 
472 		t = kann_layer_lstm (in, nnodes, rnnflags);
473 
474 		PROCESS_KAD_FLAGS (t, 4);
475 		PUSH_KAD_NODE (t);
476 	}
477 	else {
478 		return luaL_error (L, "invalid arguments, input + nnodes required");
479 	}
480 
481 	return 1;
482 }
483 
484 /***
485  * @function kann.layer.rnn(in, nnodes[, rnn_flags, [, flags]])
486  * Creates a recursive NN layer using GRU cells
487  * @param {kann_node} in kann node
488  * @param {int} nnodes number of cells
489  * @param {int} rnnflags rnn flags
490  * @param {table|int} flags optional flags
491  * @return {kann_node} kann node object (should be used to combine ANN)
492 */
493 static int
lua_kann_layer_gru(lua_State * L)494 lua_kann_layer_gru (lua_State *L)
495 {
496 	kad_node_t *in = lua_check_kann_node (L, 1);
497 	gint nnodes = luaL_checkinteger (L, 2);
498 	gint rnnflags = 0;
499 
500 	if (in != NULL && nnodes > 0) {
501 		kad_node_t *t;
502 
503 		if (lua_type (L, 3) == LUA_TNUMBER) {
504 			rnnflags = lua_tointeger (L, 3);
505 		}
506 
507 		t = kann_layer_gru (in, nnodes, rnnflags);
508 
509 		PROCESS_KAD_FLAGS (t, 4);
510 		PUSH_KAD_NODE (t);
511 	}
512 	else {
513 		return luaL_error (L, "invalid arguments, input + nnodes required");
514 	}
515 
516 	return 1;
517 }
518 
519 /***
520  * @function kann.layer.conv2d(in, n_flt, k_rows, k_cols, stride_rows, stride_cols, pad_rows, pad_columns[, flags])
521  * Creates a 2D convolution layer
522  * @param {kann_node} in kann node
523  * @param {int} n_flt number of filters
524  * @param {int} k_rows kernel rows
525  * @param {int} k_cols kernel columns
526  * @param {int} stride_rows stride rows
527  * @param {int} stride_cols stride columns
528  * @param {int} pad_rows padding rows
529  * @param {int} pad_columns padding columns
530  * @param {table|int} flags optional flags
531  * @return {kann_node} kann node object (should be used to combine ANN)
532 */
533 static int
lua_kann_layer_conv2d(lua_State * L)534 lua_kann_layer_conv2d (lua_State *L)
535 {
536 	kad_node_t *in = lua_check_kann_node (L, 1);
537 	int n_flt = luaL_checkinteger (L, 2);
538 	int k_rows = luaL_checkinteger (L, 3);
539 	int k_cols =  luaL_checkinteger (L, 4);
540 	int stride_r = luaL_checkinteger (L, 5);
541 	int stride_c = luaL_checkinteger (L, 6);
542 	int pad_r = luaL_checkinteger (L, 7);
543 	int pad_c = luaL_checkinteger (L, 8);
544 
545 	if (in != NULL) {
546 		kad_node_t *t;
547 		t = kann_layer_conv2d (in, n_flt, k_rows, k_cols, stride_r, stride_c,
548 				pad_r, pad_c);
549 
550 		PROCESS_KAD_FLAGS (t, 9);
551 		PUSH_KAD_NODE (t);
552 	}
553 	else {
554 		return luaL_error (L, "invalid arguments, input, nflt, kx, ky, stridex, stridey, padx, pady are required");
555 	}
556 
557 	return 1;
558 }
559 
560 /***
561  * @function kann.layer.conv1d(in, n_flt, kern_size, stride_size, pad_size[, flags])
562  * Creates 1D convolution layer
563  * @param {kann_node} in kann node
564  * @param {int} n_flt number of filters
565  * @param {int} kern_size kernel rows
566  * @param {int} stride_size stride rows
567  * @param {int} pad_size padding rows
568  * @param {table|int} flags optional flags
569  * @return {kann_node} kann node object (should be used to combine ANN)
570 */
571 static int
lua_kann_layer_conv1d(lua_State * L)572 lua_kann_layer_conv1d (lua_State *L)
573 {
574 	kad_node_t *in = lua_check_kann_node (L, 1);
575 	int n_flt = luaL_checkinteger (L, 2);
576 	int k_size = luaL_checkinteger (L, 3);
577 	int stride = luaL_checkinteger (L, 4);
578 	int pad = luaL_checkinteger (L, 5);
579 
580 	if (in != NULL) {
581 		kad_node_t *t;
582 		t = kann_layer_conv1d (in, n_flt, k_size, stride, pad);
583 
584 		PROCESS_KAD_FLAGS (t, 6);
585 		PUSH_KAD_NODE (t);
586 	}
587 	else {
588 		return luaL_error (L, "invalid arguments, input, nflt, k, stride, pad required");
589 	}
590 
591 	return 1;
592 }
593 
594 /***
595  * @function kann.layer.cost(in, nout, cost_type[, flags])
596  * Creates 1D convolution layer
597  * @param {kann_node} in kann node
598  * @param {int} nout number of outputs
599  * @param {int} cost_type see kann.cost table
600  * @param {table|int} flags optional flags
601  * @return {kann_node} kann node object (should be used to combine ANN)
602 */
603 static int
lua_kann_layer_cost(lua_State * L)604 lua_kann_layer_cost (lua_State *L)
605 {
606 	kad_node_t *in = lua_check_kann_node (L, 1);
607 	int nout = luaL_checkinteger (L, 2);
608 	int cost_type = luaL_checkinteger (L, 3);
609 
610 	if (in != NULL && nout > 0) {
611 		kad_node_t *t;
612 		t = kann_layer_cost (in, nout, cost_type);
613 
614 		PROCESS_KAD_FLAGS (t, 4);
615 		PUSH_KAD_NODE (t);
616 	}
617 	else {
618 		return luaL_error (L, "invalid arguments, input, nout and cost_type are required");
619 	}
620 
621 	return 1;
622 }
623 
624 /* Generic helpers */
625 static int
lua_kann_call_unary_function(lua_State * L,const char * name,kad_node_t * (* func)(kad_node_t *))626 lua_kann_call_unary_function (lua_State *L, const char *name,
627 		kad_node_t *(*func)(kad_node_t *))
628 {
629 	kad_node_t *in = lua_check_kann_node (L, 1);
630 
631 	if (in != NULL) {
632 		kad_node_t *t;
633 		t = func (in);
634 
635 		PUSH_KAD_NODE (t);
636 	}
637 	else {
638 		return luaL_error (L, "invalid arguments for %s, input required", name);
639 	}
640 
641 	return 1;
642 }
643 static int
lua_kann_call_binary_function(lua_State * L,const char * name,kad_node_t * (* func)(kad_node_t *,kad_node_t *))644 lua_kann_call_binary_function (lua_State *L, const char *name,
645 							  kad_node_t *(*func)(kad_node_t *, kad_node_t *))
646 {
647 	kad_node_t *x = lua_check_kann_node (L, 1);
648 	kad_node_t *y = lua_check_kann_node (L, 2);
649 
650 	if (x != NULL && y != NULL) {
651 		kad_node_t *t;
652 		t = func (x, y);
653 
654 		PUSH_KAD_NODE (t);
655 	}
656 	else {
657 		return luaL_error (L, "invalid arguments for %s, 2 inputs required", name);
658 	}
659 
660 	return 1;
661 }
662 
663 #define LUA_UNARY_TRANSFORM_FUNC_IMPL(name)									\
664 static int lua_kann_transform_ ##name (lua_State *L)						\
665 {																			\
666 	return lua_kann_call_unary_function(L, #name, kad_##name);				\
667 }
668 
669 #define LUA_BINARY_TRANSFORM_FUNC_IMPL(name)								\
670 static int lua_kann_transform_ ##name (lua_State *L)						\
671 {																			\
672 	return lua_kann_call_binary_function(L, #name, kad_##name);				\
673 }
674 
675 #define LUA_LOSS_FUNC_IMPL(name)											\
676 static int lua_kann_loss_ ##name (lua_State *L)								\
677 {																			\
678 	return lua_kann_call_binary_function(L, #name, kad_##name);				\
679 }
680 
681 /* Transform functions registered via macro helpers */
682 LUA_BINARY_TRANSFORM_FUNC_IMPL (add)
LUA_BINARY_TRANSFORM_FUNC_IMPL(sub)683 LUA_BINARY_TRANSFORM_FUNC_IMPL (sub)
684 LUA_BINARY_TRANSFORM_FUNC_IMPL (mul)
685 LUA_BINARY_TRANSFORM_FUNC_IMPL (cmul)
686 LUA_BINARY_TRANSFORM_FUNC_IMPL (matmul)
687 
688 LUA_UNARY_TRANSFORM_FUNC_IMPL (square)
689 LUA_UNARY_TRANSFORM_FUNC_IMPL (sigm)
690 LUA_UNARY_TRANSFORM_FUNC_IMPL (tanh)
691 LUA_UNARY_TRANSFORM_FUNC_IMPL (relu)
692 LUA_UNARY_TRANSFORM_FUNC_IMPL (softmax)
693 LUA_UNARY_TRANSFORM_FUNC_IMPL (1minus)
694 LUA_UNARY_TRANSFORM_FUNC_IMPL (exp)
695 LUA_UNARY_TRANSFORM_FUNC_IMPL (log)
696 LUA_UNARY_TRANSFORM_FUNC_IMPL (sin)
697 
698 /* Generic cost functions */
699 LUA_LOSS_FUNC_IMPL (mse)
700 LUA_LOSS_FUNC_IMPL (ce_multi)
701 LUA_LOSS_FUNC_IMPL (ce_bin)
702 LUA_LOSS_FUNC_IMPL (ce_bin_neg)
703 
704 /* The only case of ternary weight function */
705 static int
706 lua_kann_loss_ce_multi_weighted (lua_State *L)
707 {
708 	kad_node_t *pred = lua_check_kann_node (L, 1);
709 	kad_node_t *truth = lua_check_kann_node (L, 2);
710 	kad_node_t *weight = lua_check_kann_node (L, 3);
711 
712 	if (pred != NULL && truth != NULL && weight != NULL) {
713 		kad_node_t *t;
714 		t = kad_ce_multi_weighted (pred, truth, weight);
715 
716 		PUSH_KAD_NODE (t);
717 	}
718 	else {
719 		return luaL_error (L, "invalid arguments for ce_multi_weighted, 3 inputs required");
720 	}
721 
722 	return 1;
723 }
724 
725 /* Creation functions */
726 static int
lua_kann_new_scalar(lua_State * L)727 lua_kann_new_scalar (lua_State *L)
728 {
729 	gint flag = luaL_checkinteger (L, 1);
730 	double x = luaL_checknumber (L, 2);
731 	kad_node_t *t;
732 
733 	t = kann_new_scalar (flag, x);
734 
735 	PROCESS_KAD_FLAGS (t, 3);
736 	PUSH_KAD_NODE (t);
737 
738 	return 1;
739 }
740 
741 static int
lua_kann_new_weight(lua_State * L)742 lua_kann_new_weight (lua_State *L)
743 {
744 	gint nrow = luaL_checkinteger (L, 1);
745 	gint ncol = luaL_checkinteger (L, 2);
746 	kad_node_t *t;
747 
748 	t = kann_new_weight (nrow, ncol);
749 
750 	PROCESS_KAD_FLAGS (t, 3);
751 	PUSH_KAD_NODE (t);
752 
753 	return 1;
754 }
755 
756 static int
lua_kann_new_bias(lua_State * L)757 lua_kann_new_bias (lua_State *L)
758 {
759 	gint n = luaL_checkinteger (L, 1);
760 	kad_node_t *t;
761 
762 	t = kann_new_bias (n);
763 
764 	PROCESS_KAD_FLAGS (t, 2);
765 	PUSH_KAD_NODE (t);
766 
767 	return 1;
768 }
769 
770 static int
lua_kann_new_weight_conv2d(lua_State * L)771 lua_kann_new_weight_conv2d (lua_State *L)
772 {
773 	gint nout = luaL_checkinteger (L, 1);
774 	gint nin = luaL_checkinteger (L, 2);
775 	gint krow = luaL_checkinteger (L, 3);
776 	gint kcol = luaL_checkinteger (L, 4);
777 	kad_node_t *t;
778 
779 	t = kann_new_weight_conv2d (nout, nin, krow, kcol);
780 
781 	PROCESS_KAD_FLAGS (t, 5);
782 	PUSH_KAD_NODE (t);
783 
784 	return 1;
785 }
786 
787 static int
lua_kann_new_weight_conv1d(lua_State * L)788 lua_kann_new_weight_conv1d (lua_State *L)
789 {
790 	gint nout = luaL_checkinteger (L, 1);
791 	gint nin = luaL_checkinteger (L, 2);
792 	gint klen = luaL_checkinteger (L, 3);
793 	kad_node_t *t;
794 
795 	t = kann_new_weight_conv1d (nout, nin, klen);
796 
797 	PROCESS_KAD_FLAGS (t, 4);
798 	PUSH_KAD_NODE (t);
799 
800 	return 1;
801 }
802 
803 static int
lua_kann_new_leaf(lua_State * L)804 lua_kann_new_leaf (lua_State *L)
805 {
806 	int dim = luaL_checkinteger (L, 1), i, *ar;
807 	kad_node_t *t;
808 
809 	if (dim >= 1 && dim < KAD_MAX_DIM && lua_istable (L, 2)) {
810 		ar = g_new0 (int, dim);
811 
812 		for (i = 0; i < dim; i ++) {
813 			lua_rawgeti (L, 2, i + 1);
814 			ar[i] = lua_tointeger (L, -1);
815 			lua_pop (L, 1);
816 		}
817 
818 		t = kann_new_leaf_array (NULL, NULL, 0, 0.0, dim, ar);
819 
820 		PROCESS_KAD_FLAGS (t, 3);
821 		PUSH_KAD_NODE (t);
822 
823 		g_free (ar);
824 	}
825 	else {
826 		return luaL_error (L, "invalid arguments for new.leaf, "
827 						"dim and vector of elements are required");
828 	}
829 
830 	return 1;
831 }
832 
833 static int
lua_kann_new_kann(lua_State * L)834 lua_kann_new_kann (lua_State *L)
835 {
836 	kad_node_t *cost = lua_check_kann_node (L, 1);
837 	kann_t *k;
838 
839 	if (cost) {
840 		k = kann_new (cost, 0);
841 
842 		PUSH_KAN_NETWORK (k);
843 	}
844 	else {
845 		return luaL_error (L, "invalid arguments for new.kann, "
846 							  "cost node is required");
847 	}
848 
849 	return 1;
850 }
851 
852 static int
lua_kann_destroy(lua_State * L)853 lua_kann_destroy (lua_State *L)
854 {
855 	kann_t *k = lua_check_kann (L, 1);
856 
857 	kann_delete (k);
858 
859 	return 0;
860 }
861 
862 static int
lua_kann_save(lua_State * L)863 lua_kann_save (lua_State *L)
864 {
865 	kann_t *k = lua_check_kann (L, 1);
866 
867 	if (k) {
868 		if (lua_istable (L, 2)) {
869 			lua_getfield (L, 2, "filename");
870 
871 			if (lua_isstring (L, -1)) {
872 				const gchar *fname = lua_tostring (L, -1);
873 				FILE *f;
874 
875 				f = fopen (fname, "w");
876 
877 				if (!f) {
878 					lua_pop (L, 1);
879 
880 					return luaL_error (L, "cannot open %s for writing: %s",
881 							fname, strerror (errno));
882 				}
883 
884 				kann_save_fp (f, k);
885 				fclose (f);
886 
887 				lua_pushboolean (L, true);
888 			}
889 			else {
890 				lua_pop (L, 1);
891 
892 				return luaL_error (L, "invalid arguments: missing filename");
893 			}
894 
895 			lua_pop (L, 1);
896 		}
897 		else {
898 			/* Save to Rspamd text */
899 #ifndef HAVE_OPENMEMSTREAM
900 			return luaL_error (L, "no support of saving to memory on your system");
901 #endif
902 			FILE *f;
903 			char *buf = NULL;
904 			size_t buflen;
905 			struct rspamd_lua_text *t;
906 
907 			f = open_memstream (&buf, &buflen);
908 			g_assert (f != NULL);
909 
910 			kann_save_fp (f, k);
911 			fclose (f);
912 
913 			t = lua_newuserdata (L, sizeof (*t));
914 			rspamd_lua_setclass (L, "rspamd{text}", -1);
915 			t->flags = RSPAMD_TEXT_FLAG_OWN;
916 			t->start = (const gchar *)buf;
917 			t->len = buflen;
918 		}
919 	}
920 	else {
921 		return luaL_error (L, "invalid arguments");
922 	}
923 
924 	return 1;
925 }
926 
927 static int
lua_kann_load(lua_State * L)928 lua_kann_load (lua_State *L)
929 {
930 	kann_t *k;
931 	FILE *f = NULL;
932 
933 	if (lua_istable (L, 1)) {
934 		lua_getfield (L, 2, "filename");
935 
936 		if (lua_isstring (L, -1)) {
937 			const gchar *fname = lua_tostring (L, -1);
938 
939 			f = fopen (fname, "rb");
940 		}
941 		else {
942 			lua_pop (L, 1);
943 
944 			return luaL_error (L, "invalid arguments: missing filename");
945 		}
946 
947 		lua_pop (L, 1);
948 	}
949 	else if (lua_isstring (L, 1)) {
950 		gsize dlen;
951 		const gchar *data;
952 
953 		data = lua_tolstring (L, 1, &dlen);
954 
955 #ifndef HAVE_FMEMOPEN
956 		return luaL_error (L, "no support of loading from memory on your system");
957 #endif
958 		f = fmemopen ((void *)data, dlen, "rb");
959 	}
960 	else if (lua_isuserdata (L, 1)) {
961 		struct rspamd_lua_text *t;
962 
963 		t = lua_check_text (L, 1);
964 
965 		if (!t) {
966 			return luaL_error (L, "invalid arguments");
967 		}
968 
969 #ifndef HAVE_FMEMOPEN
970 		return luaL_error (L, "no support of loading from memory on your system");
971 #endif
972 		f = fmemopen ((void *)t->start, t->len, "rb");
973 	}
974 
975 	if (f == NULL) {
976 		return luaL_error (L, "invalid arguments or cannot open file");
977 	}
978 
979 	k = kann_load_fp (f);
980 	fclose (f);
981 
982 	if (k == NULL) {
983 		lua_pushnil (L);
984 	}
985 	else {
986 		PUSH_KAN_NETWORK (k);
987 	}
988 
989 	return 1;
990 }
991 
992 struct rspamd_kann_train_cbdata {
993 	lua_State *L;
994 	kann_t *k;
995 	gint cbref;
996 };
997 
998 static void
lua_kann_train_cb(int iter,float train_cost,float val_cost,void * ud)999 lua_kann_train_cb (int iter, float train_cost, float val_cost, void *ud)
1000 {
1001 	struct rspamd_kann_train_cbdata *cbd = (struct rspamd_kann_train_cbdata *)ud;
1002 
1003 	if (cbd->cbref != -1) {
1004 		gint err_idx;
1005 		lua_State *L = cbd->L;
1006 
1007 		lua_pushcfunction (L, &rspamd_lua_traceback);
1008 		err_idx = lua_gettop (L);
1009 
1010 		lua_rawgeti (L, LUA_REGISTRYINDEX, cbd->cbref);
1011 		lua_pushinteger (L, iter);
1012 		lua_pushnumber (L, train_cost);
1013 		lua_pushnumber (L, val_cost);
1014 
1015 		if (lua_pcall (L, 3, 0, err_idx) != 0) {
1016 			msg_err ("cannot run lua train callback: %s",
1017 					lua_tostring (L, -1));
1018 		}
1019 
1020 		lua_settop (L, err_idx - 1);
1021 	}
1022 }
1023 
/* Frees the first `n` row pointers of `a` and then `a` itself */
#define FREE_VEC(a, n) do { \
	for (int fv_row_ = 0; fv_row_ < (n); fv_row_ ++) { \
		g_free ((a)[fv_row_]); \
	} \
	g_free (a); \
} while (0)
1025 
1026 static int
lua_kann_train1(lua_State * L)1027 lua_kann_train1 (lua_State *L)
1028 {
1029 	kann_t *k = lua_check_kann (L, 1);
1030 	struct rspamd_lua_tensor *pca = NULL;
1031 
1032 	/* Default train params */
1033 	double lr = 0.001;
1034 	gint64 mini_size = 64;
1035 	gint64 max_epoch = 25;
1036 	gint64 max_drop_streak = 10;
1037 	double frac_val = 0.1;
1038 	gint cbref = -1;
1039 
1040 	if (k && lua_istable (L, 2) && lua_istable (L, 3)) {
1041 		int n = rspamd_lua_table_size (L, 2);
1042 		int n_in = kann_dim_in (k);
1043 		int n_out = kann_dim_out (k);
1044 
1045 		if (n_in <= 0) {
1046 			return luaL_error (L, "invalid inputs count: %d", n_in);
1047 		}
1048 
1049 		if (n_out <= 0) {
1050 			return luaL_error (L, "invalid outputs count: %d", n_out);
1051 		}
1052 
1053 		if (n != rspamd_lua_table_size (L, 3) || n == 0) {
1054 			return luaL_error (L, "invalid dimensions: outputs size must be "
1055 						 "equal to inputs and non zero");
1056 		}
1057 
1058 		if (lua_istable (L, 4)) {
1059 			GError *err = NULL;
1060 
1061 			if (!rspamd_lua_parse_table_arguments (L, 4, &err,
1062 					RSPAMD_LUA_PARSE_ARGUMENTS_IGNORE_MISSING,
1063 					"lr=N;mini_size=I;max_epoch=I;max_drop_streak=I;frac_val=N;cb=F;pca=u{tensor}",
1064 					&lr, &mini_size, &max_epoch, &max_drop_streak, &frac_val, &cbref, &pca)) {
1065 				n = luaL_error (L, "invalid params: %s",
1066 						err ? err->message : "unknown error");
1067 				g_error_free (err);
1068 
1069 				return n;
1070 			}
1071 		}
1072 
1073 		if (pca) {
1074 			/* Check pca matrix validity */
1075 			if (pca->ndims != 2) {
1076 				return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
1077 			}
1078 
1079 			if (pca->dim[0] != n_in) {
1080 				return luaL_error (L, "invalid pca tensor: "
1081 						  "matrix must have %d rows and it has %d rows instead",
1082 						  n_in, pca->dim[0]);
1083 			}
1084 		}
1085 
1086 		float **x, **y, *tmp_row = NULL;
1087 
1088 		/* Fill vectors row by row */
1089 		x = (float **)g_malloc0 (sizeof (float *) * n);
1090 		y = (float **)g_malloc0 (sizeof (float *) * n);
1091 
1092 		if (pca) {
1093 			tmp_row = g_malloc (sizeof (float) * pca->dim[1]);
1094 		}
1095 
1096 		for (int s = 0; s < n; s ++) {
1097 			/* Inputs */
1098 			lua_rawgeti (L, 2, s + 1);
1099 			x[s] = (float *)g_malloc (sizeof (float) * n_in);
1100 
1101 			if (pca == NULL) {
1102 				if (rspamd_lua_table_size (L, -1) != n_in) {
1103 					FREE_VEC (x, n);
1104 					FREE_VEC (y, n);
1105 
1106 					n = luaL_error (L, "invalid params at pos %d: "
1107 									   "bad input dimension %d; %d expected",
1108 							s + 1,
1109 							(int) rspamd_lua_table_size (L, -1),
1110 							n_in);
1111 					lua_pop (L, 1);
1112 
1113 					return n;
1114 				}
1115 
1116 				for (int i = 0; i < n_in; i++) {
1117 					lua_rawgeti (L, -1, i + 1);
1118 					x[s][i] = lua_tonumber (L, -1);
1119 
1120 					lua_pop (L, 1);
1121 				}
1122 			}
1123 			else {
1124 				if (rspamd_lua_table_size (L, -1) != pca->dim[1]) {
1125 					FREE_VEC (x, n);
1126 					FREE_VEC (y, n);
1127 					g_free (tmp_row);
1128 
1129 					n = luaL_error (L, "(pca on) invalid params at pos %d: "
1130 									   "bad input dimension %d; %d expected",
1131 							s + 1,
1132 							(int) rspamd_lua_table_size (L, -1),
1133 							pca->dim[1]);
1134 					lua_pop (L, 1);
1135 
1136 					return n;
1137 				}
1138 
1139 
1140 				for (int i = 0; i < pca->dim[1]; i++) {
1141 					lua_rawgeti (L, -1, i + 1);
1142 					tmp_row[i] = lua_tonumber (L, -1);
1143 
1144 					lua_pop (L, 1);
1145 				}
1146 
1147 				kad_sgemm_simple (0, 1, 1, n_in,
1148 						pca->dim[1], tmp_row, pca->data,
1149 						x[s]);
1150 			}
1151 
1152 			lua_pop (L, 1);
1153 
1154 			/* Outputs */
1155 			y[s] = (float *)g_malloc (sizeof (float) * n_out);
1156 			lua_rawgeti (L, 3, s + 1);
1157 
1158 			if (rspamd_lua_table_size (L, -1) != n_out) {
1159 				FREE_VEC (x, n);
1160 				FREE_VEC (y, n);
1161 				g_free (tmp_row);
1162 
1163 				n = luaL_error (L, "invalid params at pos %d: "
1164 					   "bad output dimension %d; "
1165 					   "%d expected",
1166 						s + 1,
1167 						(int)rspamd_lua_table_size (L, -1),
1168 						n_out);
1169 				lua_pop (L, 1);
1170 
1171 				return n;
1172 			}
1173 
1174 			for (int i = 0; i < n_out; i ++) {
1175 				lua_rawgeti (L, -1, i + 1);
1176 				y[s][i] = lua_tonumber (L, -1);
1177 
1178 				lua_pop (L, 1);
1179 			}
1180 
1181 			lua_pop (L, 1);
1182 		}
1183 
1184 		struct rspamd_kann_train_cbdata cbd;
1185 
1186 		cbd.cbref = cbref;
1187 		cbd.k = k;
1188 		cbd.L = L;
1189 
1190 		int niters = kann_train_fnn1 (k, lr,
1191 				mini_size, max_epoch, max_drop_streak,
1192 				frac_val, n, x, y, lua_kann_train_cb, &cbd);
1193 
1194 		lua_pushinteger (L, niters);
1195 
1196 		FREE_VEC (x, n);
1197 		FREE_VEC (y, n);
1198 		g_free (tmp_row);
1199 	}
1200 	else {
1201 		return luaL_error (L, "invalid arguments: kann, inputs, outputs and"
1202 							  " optional params are expected");
1203 	}
1204 
1205 	return 1;
1206 }
1207 
1208 static int
lua_kann_apply1(lua_State * L)1209 lua_kann_apply1 (lua_State *L)
1210 {
1211 	kann_t *k = lua_check_kann (L, 1);
1212 	struct rspamd_lua_tensor *pca = NULL;
1213 
1214 	if (k) {
1215 		if (lua_istable (L, 2)) {
1216 			gsize vec_len = rspamd_lua_table_size (L, 2);
1217 			float *vec = (float *) g_malloc (sizeof (float) * vec_len),
1218 				*pca_out = NULL;
1219 			int i_out;
1220 			int n_in = kann_dim_in (k);
1221 
1222 			if (n_in <= 0) {
1223 				g_free (vec);
1224 				return luaL_error (L, "invalid inputs count: %d", n_in);
1225 			}
1226 
1227 			if (lua_isuserdata (L, 3)) {
1228 				pca = lua_check_tensor (L, 3);
1229 
1230 				if (pca) {
1231 					if (pca->ndims != 2) {
1232 						g_free (vec);
1233 						return luaL_error (L, "invalid pca tensor: matrix expected, got a row");
1234 					}
1235 
1236 					if (pca->dim[0] != n_in) {
1237 						g_free (vec);
1238 						return luaL_error (L, "invalid pca tensor: "
1239 											  "matrix must have %d rows and it has %d rows instead",
1240 								n_in, pca->dim[0]);
1241 					}
1242 				}
1243 				else {
1244 					g_free (vec);
1245 					return luaL_error (L, "invalid params: pca matrix expected");
1246 				}
1247 			}
1248 			else {
1249 				if (n_in != vec_len) {
1250 					g_free (vec);
1251 					return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
1252 							(int) vec_len, n_in);
1253 				}
1254 			}
1255 
1256 			for (gsize i = 0; i < vec_len; i++) {
1257 				lua_rawgeti (L, 2, i + 1);
1258 				vec[i] = lua_tonumber (L, -1);
1259 				lua_pop (L, 1);
1260 			}
1261 
1262 			i_out = kann_find (k, KANN_F_OUT, 0);
1263 
1264 			if (i_out <= 0) {
1265 				g_free (vec);
1266 				return luaL_error (L, "invalid ANN: output layer is missing or is "
1267 									  "at the input pos");
1268 			}
1269 
1270 			kann_set_batch_size (k, 1);
1271 			if (pca) {
1272 				pca_out = g_malloc (sizeof (float) * n_in);
1273 
1274 				kad_sgemm_simple (0, 1, 1, n_in,
1275 						vec_len, vec, pca->data,
1276 						pca_out);
1277 
1278 				kann_feed_bind (k, KANN_F_IN, 0, &pca_out);
1279 			}
1280 			else {
1281 				kann_feed_bind (k, KANN_F_IN, 0, &vec);
1282 			}
1283 
1284 			kad_eval_at (k->n, k->v, i_out);
1285 
1286 			gsize outlen = kad_len (k->v[i_out]);
1287 			lua_createtable (L, outlen, 0);
1288 
1289 			for (gsize i = 0; i < outlen; i++) {
1290 				lua_pushnumber (L, k->v[i_out]->x[i]);
1291 				lua_rawseti (L, -2, i + 1);
1292 			}
1293 
1294 			g_free (vec);
1295 			g_free (pca_out);
1296 		}
1297 		else if (lua_isuserdata (L, 2)) {
1298 			struct rspamd_lua_tensor *t = lua_check_tensor (L, 2);
1299 
1300 			if (t && t->ndims == 1) {
1301 				int i_out;
1302 				int n_in = kann_dim_in (k);
1303 
1304 				if (n_in != t->dim[0]) {
1305 					return luaL_error (L, "invalid params: bad input dimension %d; %d expected",
1306 							(int) t->dim[0], n_in);
1307 				}
1308 
1309 				i_out = kann_find (k, KANN_F_OUT, 0);
1310 
1311 				if (i_out <= 0) {
1312 					return luaL_error (L, "invalid ANN: output layer is missing or is "
1313 										  "at the input pos");
1314 				}
1315 
1316 				kann_set_batch_size (k, 1);
1317 				kann_feed_bind (k, KANN_F_IN, 0, &t->data);
1318 				kad_eval_at (k->n, k->v, i_out);
1319 
1320 				gint outlen = kad_len (k->v[i_out]);
1321 				struct rspamd_lua_tensor *out;
1322 				out = lua_newtensor (L, 1, &outlen, false, false);
1323 				/* Ensure that kann and tensor have the same understanding of floats */
1324 				G_STATIC_ASSERT (sizeof (float) == sizeof (rspamd_tensor_num_t));
1325 				memcpy (out->data, k->v[i_out]->x, outlen * sizeof (float));
1326 			}
1327 			else {
1328 				return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
1329 			}
1330 		}
1331 		else {
1332 			return luaL_error (L, "invalid arguments: 1D rspamd{tensor} expected");
1333 		}
1334 	}
1335 	else {
1336 		return luaL_error (L, "invalid arguments: rspamd{kann} expected");
1337 	}
1338 
1339 	return 1;
1340 }