#include "config.h"

#include <math.h>
#include <float.h>
#include <string.h>
#include <stdlib.h>
#include <assert.h>
#include <stdarg.h>
#include "kann.h"

int kann_verbose = 3;

/******************************************
 *** @@BASIC: fundamental KANN routines ***
 ******************************************/

/* Copy the values of all variable/constant nodes into the single contiguous x/g/c arrays owned by kann_t */
static void kad_ext_collate(int n, kad_node_t **a, float **_x, float **_g, float **_c)
{
	int i, j, k, l, n_var;
	float *x, *g, *c;
	n_var = kad_size_var(n, a);
	x = *_x = (float*)realloc(*_x, n_var * sizeof(float));
	g = *_g = (float*)realloc(*_g, n_var * sizeof(float));
	c = *_c = (float*)realloc(*_c, kad_size_const(n, a) * sizeof(float));
	memset(g, 0, n_var * sizeof(float));
	for (i = j = k = 0; i < n; ++i) {
		kad_node_t *v = a[i];
		if (kad_is_var(v)) {
			l = kad_len(v);
			memcpy(&x[j], v->x, l * sizeof(float));
			free(v->x);
			v->x = &x[j];
			v->g = &g[j];
			j += l;
		} else if (kad_is_const(v)) {
			l = kad_len(v);
			memcpy(&c[k], v->x, l * sizeof(float));
			free(v->x);
			v->x = &c[k];
			k += l;
		}
	}
}

/* Re-point variable/constant nodes into existing x/g/c arrays without copying (used after loading or unrolling) */
static void kad_ext_sync(int n, kad_node_t **a, float *x, float *g, float *c)
{
	int i, j, k;
	for (i = j = k = 0; i < n; ++i) {
		kad_node_t *v = a[i];
		if (kad_is_var(v)) {
			v->x = &x[j];
			v->g = &g[j];
			j += kad_len(v);
		} else if (kad_is_const(v)) {
			v->x = &c[k];
			k += kad_len(v);
		}
	}
}

kann_t *kann_new(kad_node_t *cost, int n_rest, ...)
{
	kann_t *a;
	int i, n_roots = 1 + n_rest, has_pivot = 0, has_recur = 0;
	kad_node_t **roots;
	va_list ap;

	if (cost->n_d != 0) return 0;

	va_start(ap, n_rest);
	roots = (kad_node_t**)malloc((n_roots + 1) * sizeof(kad_node_t*));
	for (i = 0; i < n_rest; ++i)
		roots[i] = va_arg(ap, kad_node_t*);
	roots[i++] = cost;
	va_end(ap);

	cost->ext_flag |= KANN_F_COST;
	a = (kann_t*)calloc(1, sizeof(kann_t));
	a->v = kad_compile_array(&a->n, n_roots, roots);

	for (i = 0; i < a->n; ++i) {
		if (a->v[i]->pre) has_recur = 1;
		if (kad_is_pivot(a->v[i])) has_pivot = 1;
	}
	if (has_recur && !has_pivot) { /* an RNN that doesn't have a pivot; then add a pivot on top of cost and recompile */
		cost->ext_flag &= ~KANN_F_COST;
		roots[n_roots-1] = cost = kad_avg(1, &cost), cost->ext_flag |= KANN_F_COST;
		free(a->v);
		a->v = kad_compile_array(&a->n, n_roots, roots);
	}
	kad_ext_collate(a->n, a->v, &a->x, &a->g, &a->c);
	free(roots);
	return a;
}
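
/* Typical model construction, as a sketch (layer sizes are illustrative;
 * kad_relu() is provided by the kautodiff layer):
 *
 *   kad_node_t *t = kann_layer_input(784);
 *   t = kad_relu(kann_layer_dense(t, 64));
 *   kann_t *ann = kann_new(kann_layer_cost(t, 10, KANN_C_CEM), 0);
 *
 * kann_new() compiles the graph rooted at the cost node and collates all
 * variables into ann->x, gradients into ann->g and constants into ann->c. */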

kann_t *kann_clone(kann_t *a, int batch_size)
{
	kann_t *b;
	b = (kann_t*)calloc(1, sizeof(kann_t));
	b->n = a->n;
	b->v = kad_clone(a->n, a->v, batch_size);
	kad_ext_collate(b->n, b->v, &b->x, &b->g, &b->c);
	return b;
}

kann_t *kann_unroll_array(kann_t *a, int *len)
{
	kann_t *b;
	b = (kann_t*)calloc(1, sizeof(kann_t));
	b->x = a->x, b->g = a->g, b->c = a->c; /* these arrays are shared */
	b->v = kad_unroll(a->n, a->v, &b->n, len);
	return b;
}

kann_t *kann_unroll(kann_t *a, ...)
{
	kann_t *b;
	va_list ap;
	int i, n_pivots, *len;
	n_pivots = kad_n_pivots(a->n, a->v);
	len = (int*)calloc(n_pivots, sizeof(int));
	va_start(ap, a);
	for (i = 0; i < n_pivots; ++i) len[i] = va_arg(ap, int);
	va_end(ap);
	b = kann_unroll_array(a, len);
	free(len);
	return b;
}
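
/* Unrolling sketch for RNN training (ulen, the unroll length, is
 * illustrative). The unrolled network shares x/g/c with the original, so it
 * must be freed with kann_delete_unrolled(), not kann_delete():
 *
 *   kann_t *ua = kann_unroll(ann, ulen);
 *   ...feed and train on ua...
 *   kann_delete_unrolled(ua);
 */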

void kann_delete_unrolled(kann_t *a)
{
	if (a && a->mt) kann_mt(a, 0, 0);
	if (a && a->v) kad_delete(a->n, a->v);
	free(a);
}

void kann_delete(kann_t *a)
{
	if (a == 0) return;
	free(a->x); free(a->g); free(a->c);
	kann_delete_unrolled(a);
}

static void kann_switch_core(kann_t *a, int is_train)
{
	int i;
	for (i = 0; i < a->n; ++i)
		if (a->v[i]->op == 12 && a->v[i]->n_child == 2) /* op 12 is kad_switch(), as generated by the dropout layer */
			*(int32_t*)a->v[i]->ptr = !!is_train;
}

#define chk_flg(flag, mask) ((mask) == 0 || ((flag) & (mask)))
#define chk_lbl(label, query) ((query) == 0 || (label) == (query))

int kann_find(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
{
	int i, k, r = -1;
	for (i = k = 0; i < a->n; ++i)
		if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
			++k, r = i;
	return k == 1? r : k == 0? -1 : -2;
}

int kann_feed_bind(kann_t *a, uint32_t ext_flag, int32_t ext_label, float **x)
{
	int i, k;
	if (x == 0) return 0;
	for (i = k = 0; i < a->n; ++i)
		if (kad_is_feed(a->v[i]) && chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
			a->v[i]->x = x[k++];
	return k;
}

int kann_feed_dim(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
{
	int i, k, n = 0;
	for (i = k = 0; i < a->n; ++i)
		if (kad_is_feed(a->v[i]) && chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
			++k, n = a->v[i]->n_d > 1? kad_len(a->v[i]) / a->v[i]->d[0] : a->v[i]->n_d == 1? a->v[i]->d[0] : 1;
	return k == 1? n : k == 0? -1 : -2;
}

static float kann_cost_core(kann_t *a, int cost_label, int cal_grad)
{
	int i_cost;
	float cost;
	i_cost = kann_find(a, KANN_F_COST, cost_label);
	assert(i_cost >= 0);
	cost = *kad_eval_at(a->n, a->v, i_cost);
	if (cal_grad) kad_grad(a->n, a->v, i_cost);
	return cost;
}

int kann_eval(kann_t *a, uint32_t ext_flag, int ext_label)
{
	int i, k;
	for (i = k = 0; i < a->n; ++i)
		if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
			++k, a->v[i]->tmp = 1;
	kad_eval_marked(a->n, a->v);
	return k;
}

void kann_rnn_start(kann_t *a)
{
	int i;
	kann_set_batch_size(a, 1);
	for (i = 0; i < a->n; ++i) {
		kad_node_t *p = a->v[i];
		if (p->pre) { /* NB: BE CAREFUL of the interaction between kann_rnn_start() and kann_set_batch_size() */
			kad_node_t *q = p->pre;
			if (q->x) memcpy(p->x, q->x, kad_len(p) * sizeof(float));
			else memset(p->x, 0, kad_len(p) * sizeof(float));
			if (q->n_child > 0) free(q->x);
			q->x = p->x;
		}
	}
}
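
/* Streaming-inference sketch for a recurrent model (x1 holds one input step;
 * the step count is illustrative). kann_rnn_start() fixes the batch size at 1
 * and aliases each recurrent node to its own output so that state carries
 * over between calls; kann_rnn_end() restores the original layout:
 *
 *   kann_rnn_start(ann);
 *   for (t = 0; t < n_steps; ++t) {
 *       const float *y = kann_apply1(ann, x1);
 *       ...consume y...
 *   }
 *   kann_rnn_end(ann);
 */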

void kann_rnn_end(kann_t *a)
{
	int i;
	kad_ext_sync(a->n, a->v, a->x, a->g, a->c);
	for (i = 0; i < a->n; ++i)
		if (a->v[i]->pre && a->v[i]->pre->n_child > 0)
			a->v[i]->pre->x = (float*)calloc(kad_len(a->v[i]->pre), sizeof(float));
}

static int kann_class_error_core(const kann_t *ann, int *base)
{
	int i, j, k, m, n, off, n_err = 0;
	for (i = 0, *base = 0; i < ann->n; ++i) {
		kad_node_t *p = ann->v[i];
		if (((p->op == 13 && (p->n_child == 2 || p->n_child == 3)) || (p->op == 22 && p->n_child == 2)) && p->n_d == 0) { /* ce_bin or ce_multi */
			kad_node_t *x = p->child[0], *t = p->child[1];
			n = t->d[t->n_d - 1], m = kad_len(t) / n;
			for (j = off = 0; j < m; ++j, off += n) {
				float t_sum = 0.0f, t_min = 1.0f, t_max = 0.0f, x_max = 0.0f, x_min = 1.0f;
				int x_max_k = -1, t_max_k = -1;
				for (k = 0; k < n; ++k) {
					float xk = x->x[off+k], tk = t->x[off+k];
					t_sum += tk;
					t_min = t_min < tk? t_min : tk;
					x_min = x_min < xk? x_min : xk;
					if (t_max < tk) t_max = tk, t_max_k = k;
					if (x_max < xk) x_max = xk, x_max_k = k;
				}
				if (t_sum - 1.0f == 0 && t_min >= 0.0f && x_min >= 0.0f && x_max <= 1.0f) { /* count only rows where the truth sums to 1 and the output looks like probabilities */
					++(*base);
					n_err += (x_max_k != t_max_k);
				}
			}
		}
	}
	return n_err;
}

/*************************
 * @@MT: multi-threading *
 *************************/

#ifdef HAVE_PTHREAD
#include <pthread.h>

struct mtaux_t;

typedef struct { /* per-worker data */
	kann_t *a;
	float cost;
	int action;
	pthread_t tid;
	struct mtaux_t *g;
} mtaux1_t;

typedef struct mtaux_t { /* cross-worker data */
	int n_threads, max_batch_size;
	int cal_grad, cost_label, eval_out;
	volatile int n_idle; /* we will be busy waiting on this, so volatile is necessary */
	pthread_mutex_t mtx;
	pthread_cond_t cv;
	mtaux1_t *mt;
} mtaux_t;

static void *mt_worker(void *data) /* pthread worker */
{
	mtaux1_t *mt1 = (mtaux1_t*)data;
	mtaux_t *mt = mt1->g;
	for (;;) {
		int action;
		pthread_mutex_lock(&mt->mtx);
		mt1->action = 0;
		++mt->n_idle;
		while (mt1->action == 0)
			pthread_cond_wait(&mt->cv, &mt->mtx);
		action = mt1->action;
		pthread_mutex_unlock(&mt->mtx);
		if (action == -1) break;

		if (mt->eval_out) kann_eval(mt1->a, KANN_F_OUT, 0);
		else mt1->cost = kann_cost_core(mt1->a, mt->cost_label, mt->cal_grad);
	}
	pthread_exit(0);
}

static void mt_destroy(mtaux_t *mt) /* de-allocate an entire mtaux_t struct */
{
	int i;
	pthread_mutex_lock(&mt->mtx);
	mt->n_idle = 0;
	for (i = 1; i < mt->n_threads; ++i) mt->mt[i].action = -1;
	pthread_cond_broadcast(&mt->cv);
	pthread_mutex_unlock(&mt->mtx);
	for (i = 1; i < mt->n_threads; ++i) pthread_join(mt->mt[i].tid, 0);
	for (i = 0; i < mt->n_threads; ++i) kann_delete(mt->mt[i].a);
	free(mt->mt);
	pthread_cond_destroy(&mt->cv);
	pthread_mutex_destroy(&mt->mtx);
	free(mt);
}

void kann_mt(kann_t *ann, int n_threads, int max_batch_size)
{
	mtaux_t *mt;
	int i, k;

	if (n_threads <= 1) {
		if (ann->mt) mt_destroy((mtaux_t*)ann->mt);
		ann->mt = 0;
		return;
	}
	if (n_threads > max_batch_size) n_threads = max_batch_size;
	if (n_threads <= 1) return;

	mt = (mtaux_t*)calloc(1, sizeof(mtaux_t));
	mt->n_threads = n_threads, mt->max_batch_size = max_batch_size;
	pthread_mutex_init(&mt->mtx, 0);
	pthread_cond_init(&mt->cv, 0);
	mt->mt = (mtaux1_t*)calloc(n_threads, sizeof(mtaux1_t));
	for (i = k = 0; i < n_threads; ++i) {
		int size = (max_batch_size - k) / (n_threads - i);
		mt->mt[i].a = kann_clone(ann, size);
		mt->mt[i].g = mt;
		k += size;
	}
	for (i = 1; i < n_threads; ++i)
		pthread_create(&mt->mt[i].tid, 0, mt_worker, &mt->mt[i]);
	while (mt->n_idle < n_threads - 1); /* busy waiting until all threads in sync */
	ann->mt = mt;
}
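
/* Usage sketch (thread and batch counts are illustrative): call once before
 * training to shard each mini-batch across per-thread clones; kann_cost()
 * then runs the workers in parallel and averages their gradients:
 *
 *   kann_mt(ann, 4, 64);
 *   ...kann_cost(ann, 0, 1) is now multi-threaded...
 *   kann_mt(ann, 0, 0);
 *
 * The final call with n_threads <= 1 tears the worker pool down. */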

static void mt_kickoff(kann_t *a, int cost_label, int cal_grad, int eval_out)
{
	mtaux_t *mt = (mtaux_t*)a->mt;
	int i, j, k, B, n_var;

	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
	assert(B <= mt->max_batch_size); /* TODO: can be relaxed */
	n_var = kann_size_var(a);

	pthread_mutex_lock(&mt->mtx);
	mt->cost_label = cost_label, mt->cal_grad = cal_grad, mt->eval_out = eval_out;
	for (i = k = 0; i < mt->n_threads; ++i) {
		int size = (B - k) / (mt->n_threads - i);
		for (j = 0; j < a->n; ++j)
			if (kad_is_feed(a->v[j]))
				mt->mt[i].a->v[j]->x = &a->v[j]->x[k * kad_len(a->v[j]) / a->v[j]->d[0]];
		kad_sync_dim(mt->mt[i].a->n, mt->mt[i].a->v, size); /* TODO: we can point ->x to internal nodes, too */
		k += size;
		memcpy(mt->mt[i].a->x, a->x, n_var * sizeof(float));
		mt->mt[i].action = 1;
	}
	mt->n_idle = 0;
	pthread_cond_broadcast(&mt->cv);
	pthread_mutex_unlock(&mt->mtx);
}

float kann_cost(kann_t *a, int cost_label, int cal_grad)
{
	mtaux_t *mt = (mtaux_t*)a->mt;
	int i, j, B, k, n_var;
	float cost;

	if (mt == 0) return kann_cost_core(a, cost_label, cal_grad);
	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
	n_var = kann_size_var(a);

	mt_kickoff(a, cost_label, cal_grad, 0);
	mt->mt[0].cost = kann_cost_core(mt->mt[0].a, cost_label, cal_grad);
	while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */

	memset(a->g, 0, n_var * sizeof(float)); /* TODO: check if this is necessary when cal_grad is false */
	for (i = k = 0, cost = 0.0f; i < mt->n_threads; ++i) {
		int size = (B - k) / (mt->n_threads - i);
		cost += mt->mt[i].cost * size / B;
		kad_saxpy(n_var, (float)size / B, mt->mt[i].a->g, a->g);
		k += size;
	}
	for (j = 0; j < a->n; ++j) { /* copy values back at recurrent nodes (needed by textgen; TODO: temporary solution) */
		kad_node_t *p = a->v[j];
		if (p->pre && p->n_d >= 2 && p->d[0] == B) {
			for (i = k = 0; i < mt->n_threads; ++i) {
				kad_node_t *q = mt->mt[i].a->v[j];
				memcpy(&p->x[k], q->x, kad_len(q) * sizeof(float));
				k += kad_len(q);
			}
		}
	}
	return cost;
}

int kann_eval_out(kann_t *a)
{
	mtaux_t *mt = (mtaux_t*)a->mt;
	int j, B, n_eval;
	if (mt == 0) return kann_eval(a, KANN_F_OUT, 0);
	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
	mt_kickoff(a, 0, 0, 1);
	n_eval = kann_eval(mt->mt[0].a, KANN_F_OUT, 0);
	while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */
	for (j = 0; j < a->n; ++j) { /* copy output values back */
		kad_node_t *p = a->v[j];
		if (p->ext_flag & KANN_F_OUT) {
			int i, t, k, d0 = p->d[0] / B, d1 = 1; /* for RNN, p->d[0] may equal unroll_len * batch_size */
			assert(p->d[0] % B == 0);
			for (i = 1; i < p->n_d; ++i) d1 *= p->d[i];
			for (i = 0; i < d0; ++i) {
				for (t = k = 0; t < mt->n_threads; ++t) { /* similar to the forward pass of kad_op_concat() */
					kad_node_t *q = mt->mt[t].a->v[j];
					int size = q->d[0] / d0;
					memcpy(&p->x[(i * B + k) * d1], &q->x[i * size * d1], size * d1 * sizeof(float));
					k += size;
				}
			}
		}
	}
	return n_eval;
}

int kann_class_error(const kann_t *ann, int *base)
{
	mtaux_t *mt = (mtaux_t*)ann->mt;
	int i, n_err = 0, b = 0;
	if (mt == 0) return kann_class_error_core(ann, base);
	for (i = 0, *base = 0; i < mt->n_threads; ++i) { /* NB: *base must be zeroed here; the core routine only resets b */
		n_err += kann_class_error_core(mt->mt[i].a, &b);
		*base += b;
	}
	return n_err;
}

void kann_switch(kann_t *ann, int is_train)
{
	mtaux_t *mt = (mtaux_t*)ann->mt;
	int i;
	if (mt == 0) {
		kann_switch_core(ann, is_train);
		return;
	}
	for (i = 0; i < mt->n_threads; ++i)
		kann_switch_core(mt->mt[i].a, is_train);
}
#else
void kann_mt(kann_t *ann, int n_threads, int max_batch_size) {}
float kann_cost(kann_t *a, int cost_label, int cal_grad) { return kann_cost_core(a, cost_label, cal_grad); }
int kann_eval_out(kann_t *a) { return kann_eval(a, KANN_F_OUT, 0); }
int kann_class_error(const kann_t *a, int *base) { return kann_class_error_core(a, base); }
void kann_switch(kann_t *ann, int is_train) { kann_switch_core(ann, is_train); }
#endif

/***********************
 *** @@IO: model I/O ***
 ***********************/

#define KANN_MAGIC "KAN\1"

void kann_save_fp(FILE *fp, kann_t *ann)
{
	kann_set_batch_size(ann, 1);
	fwrite(KANN_MAGIC, 1, 4, fp);
	kad_save(fp, ann->n, ann->v);
	fwrite(ann->x, sizeof(float), kann_size_var(ann), fp);
	fwrite(ann->c, sizeof(float), kann_size_const(ann), fp);
}

void kann_save(const char *fn, kann_t *ann)
{
	FILE *fp;
	fp = fn && strcmp(fn, "-")? fopen(fn, "wb") : stdout;
	if (fp == 0) return; /* fopen() failed */
	kann_save_fp(fp, ann);
	fclose(fp);
}

kann_t *kann_load_fp(FILE *fp)
{
	char magic[4];
	kann_t *ann;
	int n_var, n_const;

	(void) !fread(magic, 1, 4, fp);
	if (strncmp(magic, KANN_MAGIC, 4) != 0) {
		return 0;
	}
	ann = (kann_t*)calloc(1, sizeof(kann_t));
	ann->v = kad_load(fp, &ann->n);
	n_var = kad_size_var(ann->n, ann->v);
	n_const = kad_size_const(ann->n, ann->v);
	ann->x = (float*)malloc(n_var * sizeof(float));
	ann->g = (float*)calloc(n_var, sizeof(float));
	ann->c = (float*)malloc(n_const * sizeof(float));
	(void) !fread(ann->x, sizeof(float), n_var, fp);
	(void) !fread(ann->c, sizeof(float), n_const, fp);
	kad_ext_sync(ann->n, ann->v, ann->x, ann->g, ann->c);
	return ann;
}

kann_t *kann_load(const char *fn)
{
	FILE *fp;
	kann_t *ann;
	fp = fn && strcmp(fn, "-")? fopen(fn, "rb") : stdin;
	if (fp == 0) return 0; /* fopen() failed */
	ann = kann_load_fp(fp);
	fclose(fp);
	return ann;
}
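
/* Round-trip sketch (the file name is illustrative); passing "-" or NULL
 * selects stdout/stdin instead of a regular file:
 *
 *   kann_save("model.kan", ann);
 *   kann_t *copy = kann_load("model.kan");
 *
 * Note that kann_save_fp() resets the batch size to 1 before serializing. */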

/*********************************************
 *** @@LAYER: layers and model generation ***
 *********************************************/

/********** General but more complex APIs **********/

kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM])
{
	int i, len, off = offset && par? *offset : -1;
	kad_node_t *p;

	if (off >= 0 && par[off]) return par[(*offset)++];
	p = (kad_node_t*)calloc(1, sizeof(kad_node_t));
	p->n_d = n_d, p->flag = flag;
	memcpy(p->d, d, n_d * sizeof(int32_t));
	len = kad_len(p);
	p->x = (float*)calloc(len, sizeof(float));
	if (p->n_d <= 1) {
		for (i = 0; i < len; ++i)
			p->x[i] = x0_01;
	} else {
		double sdev_inv;
		sdev_inv = 1.0 / sqrt((double)len / p->d[0]); /* i.e. initialize weights with stdev 1/sqrt(fan-in) */
		for (i = 0; i < len; ++i)
			p->x[i] = (float)(kad_drand_normal(0) * sdev_inv);
	}
	if (off >= 0) par[off] = p, ++(*offset);
	return p;
}

kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...)
{
	int32_t i, d[KAD_MAX_DIM];
	va_list ap;
	va_start(ap, n_d); for (i = 0; i < n_d; ++i) d[i] = va_arg(ap, int); va_end(ap);
	return kann_new_leaf_array(offset, par, flag, x0_01, n_d, d);
}

kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1)
{
	int n0;
	kad_node_t *w, *b;
	n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
	w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	return kad_add(kad_cmul(in, w), b);
}

kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r)
{
	kad_node_t *x[2], *cr;
	cr = kann_new_leaf2(offset, par, KAD_CONST, r, 0);
	x[0] = t, x[1] = kad_dropout(t, cr);
	return kad_switch(2, x);
}

kad_node_t *kann_layer_layernorm2(int *offset, kad_node_t **par, kad_node_t *in)
{
	int n0;
	kad_node_t *alpha, *beta;
	n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
	alpha = kann_new_leaf2(offset, par, KAD_VAR, 1.0f, 1, n0);
	beta  = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n0);
	return kad_add(kad_mul(kad_stdnorm(in), alpha), beta);
}

static inline kad_node_t *cmul_norm2(int *offset, kad_node_t **par, kad_node_t *x, kad_node_t *w, int use_norm)
{
	return use_norm? kann_layer_layernorm2(offset, par, kad_cmul(x, w)) : kad_cmul(x, w);
}

kad_node_t *kann_layer_rnn2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag)
{
	int n0, n1 = h0->d[h0->n_d-1], use_norm = !!(rnn_flag & KANN_RNN_NORM);
	kad_node_t *t, *w, *u, *b, *out;

	/* h_t = tanh(x_t * W + h_{t-1} * U + b) */
	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, h0, u, use_norm);
	if (in) {
		n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	out = kad_tanh(kad_add(t, b));
	out->pre = h0;
	return out;
}

kad_node_t *kann_layer_gru2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag)
{
	int n0 = 0, n1 = h0->d[h0->n_d-1], use_norm = !!(rnn_flag & KANN_RNN_NORM);
	kad_node_t *t, *r, *z, *w, *u, *b, *s, *out;

	if (in) n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
	/* z = sigm(x_t * W_z + h_{t-1} * U_z + b_z) */
	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, h0, u, use_norm);
	if (in) {
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	z = kad_sigm(kad_add(t, b));
	/* r = sigm(x_t * W_r + h_{t-1} * U_r + b_r) */
	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, h0, u, use_norm);
	if (in) {
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	r = kad_sigm(kad_add(t, b));
	/* s = tanh(x_t * W_s + (h_{t-1} # r) * U_s + b_s) */
	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, kad_mul(r, h0), u, use_norm);
	if (in) {
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	s = kad_tanh(kad_add(t, b));
	/* h_t = z # h_{t-1} + (1 - z) # s */
	out = kad_add(kad_mul(kad_1minus(z), s), kad_mul(z, h0));
	out->pre = h0;
	return out;
}

/********** APIs without offset & par **********/

kad_node_t *kann_new_leaf(uint8_t flag, float x0_01, int n_d, ...)
{
	int32_t i, d[KAD_MAX_DIM];
	va_list ap;
	va_start(ap, n_d); for (i = 0; i < n_d; ++i) d[i] = va_arg(ap, int); va_end(ap);
	return kann_new_leaf_array(0, 0, flag, x0_01, n_d, d);
}

kad_node_t *kann_new_scalar(uint8_t flag, float x) { return kann_new_leaf(flag, x, 0); }
kad_node_t *kann_new_weight(int n_row, int n_col) { return kann_new_leaf(KAD_VAR, 0.0f, 2, n_row, n_col); }
kad_node_t *kann_new_vec(int n, float x) { return kann_new_leaf(KAD_VAR, x, 1, n); }
kad_node_t *kann_new_bias(int n) { return kann_new_vec(n, 0.0f); }
kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col) { return kann_new_leaf(KAD_VAR, 0.0f, 4, n_out, n_in, k_row, k_col); }
kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return kann_new_leaf(KAD_VAR, 0.0f, 3, n_out, n_in, kernel_len); }

kad_node_t *kann_layer_input(int n1)
{
	kad_node_t *t;
	t = kad_feed(2, 1, n1);
	t->ext_flag |= KANN_F_IN;
	return t;
}

kad_node_t *kann_layer_dense(kad_node_t *in, int n1) { return kann_layer_dense2(0, 0, in, n1); }
kad_node_t *kann_layer_dropout(kad_node_t *t, float r) { return kann_layer_dropout2(0, 0, t, r); }
kad_node_t *kann_layer_layernorm(kad_node_t *in) { return kann_layer_layernorm2(0, 0, in); }

kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag)
{
	kad_node_t *h0;
	h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	h0->x = (float*)calloc(n1, sizeof(float));
	return kann_layer_rnn2(0, 0, in, h0, rnn_flag);
}

kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag)
{
	kad_node_t *h0;
	h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	h0->x = (float*)calloc(n1, sizeof(float));
	return kann_layer_gru2(0, 0, in, h0, rnn_flag);
}

static kad_node_t *kann_cmul_norm(kad_node_t *x, kad_node_t *w)
{
	return kann_layer_layernorm(kad_cmul(x, w));
}

kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag)
{
	int n0;
	kad_node_t *i, *f, *o, *g, *w, *u, *b, *h0, *c0, *c, *out;
	kad_node_t *(*cmul)(kad_node_t*, kad_node_t*) = (rnn_flag & KANN_RNN_NORM)? kann_cmul_norm : kad_cmul;

	n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
	h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	h0->x = (float*)calloc(n1, sizeof(float));
	c0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	c0->x = (float*)calloc(n1, sizeof(float));

	/* i = sigm(x_t * W_i + h_{t-1} * U_i + b_i) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_bias(n1);
	i = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* f = sigm(x_t * W_f + h_{t-1} * U_f + b_f) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_vec(n1, 1.0f); /* see Jozefowicz et al on using a large bias */
	f = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* o = sigm(x_t * W_o + h_{t-1} * U_o + b_o) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_bias(n1);
	o = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* g = tanh(x_t * W_g + h_{t-1} * U_g + b_g) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_bias(n1);
	g = kad_tanh(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* c_t = c_{t-1} # f + g # i */
	c = kad_add(kad_mul(f, c0), kad_mul(g, i)); /* can't be kad_mul(c0, f)!!! */
	c->pre = c0;
	/* h_t = tanh(c_t) # o */
	if (rnn_flag & KANN_RNN_NORM) c = kann_layer_layernorm(c); /* see Ba et al (2016) about how to apply layer normalization to LSTM */
	out = kad_mul(kad_tanh(c), o);
	out->pre = h0;
	return out;
}

kad_node_t *kann_layer_conv2d(kad_node_t *in, int n_flt, int k_rows, int k_cols, int stride_r, int stride_c, int pad_r, int pad_c)
{
	kad_node_t *w;
	w = kann_new_weight_conv2d(n_flt, in->d[1], k_rows, k_cols);
	return kad_conv2d(in, w, stride_r, stride_c, pad_r, pad_c);
}

kad_node_t *kann_layer_conv1d(kad_node_t *in, int n_flt, int k_size, int stride, int pad)
{
	kad_node_t *w;
	w = kann_new_weight_conv1d(n_flt, in->d[1], k_size);
	return kad_conv1d(in, w, stride, pad);
}

kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
{
	kad_node_t *cost = 0, *truth = 0;
	assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
	t = kann_layer_dense(t, n_out);
	truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;

	if (cost_type == KANN_C_MSE) {
		cost = kad_mse(t, truth);
	} else if (cost_type == KANN_C_CEB) {
		t = kad_sigm(t);
		cost = kad_ce_bin(t, truth);
	} else if (cost_type == KANN_C_CEB_NEG) {
		t = kad_tanh(t);
		cost = kad_ce_bin_neg(t, truth);
	} else if (cost_type == KANN_C_CEM) {
		t = kad_softmax(t);
		cost = kad_ce_multi(t, truth);
	} else {
		assert(0); /* unreachable: guarded by the assert above */
	}

	t->ext_flag |= KANN_F_OUT;
	cost->ext_flag |= KANN_F_COST;

	return cost;
}
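
/* Sketch of a binary classifier over 100 inputs (sizes are illustrative):
 * kann_layer_cost() appends the final dense projection, the activation
 * matching the chosen cost (sigmoid for KANN_C_CEB) and the cost node, and
 * creates the KANN_F_TRUTH feed:
 *
 *   kad_node_t *t = kann_layer_input(100);
 *   t = kad_relu(kann_layer_dense(t, 32));
 *   kann_t *ann = kann_new(kann_layer_cost(t, 1, KANN_C_CEB), 0);
 */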

void kann_shuffle(int n, int *s)
{
	int i, j, t;
	for (i = 0; i < n; ++i) s[i] = i;
	for (i = n; i > 0; --i) { /* Fisher-Yates shuffle */
		j = (int)(i * kad_drand(0));
		t = s[j], s[j] = s[i-1], s[i-1] = t;
	}
}

/***************************
 *** @@MIN: minimization ***
 ***************************/
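
/* Per-coordinate RMSprop update, as implemented by both versions of
 * kann_RMSprop() below; the learning rate is h[i] when per-parameter rates
 * are supplied, otherwise h0:
 *
 *   r[i] <- (1 - decay) * g[i]^2 + decay * r[i]
 *   t[i] <- t[i] - lr / sqrt(r[i] + 1e-6) * g[i]
 */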

#ifdef __SSE__
#include <xmmintrin.h>

void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g, float *t, float *r)
{
	int i, n4 = n>>2<<2;
	__m128 vh, vg, vr, vt, vd, vd1, tmp, vtiny;
	vh = _mm_set1_ps(h0);
	vd = _mm_set1_ps(decay);
	vd1 = _mm_set1_ps(1.0f - decay);
	vtiny = _mm_set1_ps(1e-6f);
	for (i = 0; i < n4; i += 4) {
		vt = _mm_loadu_ps(&t[i]);
		vr = _mm_loadu_ps(&r[i]);
		vg = _mm_loadu_ps(&g[i]);
		if (h) vh = _mm_loadu_ps(&h[i]);
		vr = _mm_add_ps(_mm_mul_ps(vd1, _mm_mul_ps(vg, vg)), _mm_mul_ps(vd, vr));
		_mm_storeu_ps(&r[i], vr);
		tmp = _mm_sub_ps(vt, _mm_mul_ps(_mm_mul_ps(vh, _mm_rsqrt_ps(_mm_add_ps(vtiny, vr))), vg));
		_mm_storeu_ps(&t[i], tmp);
	}
	for (; i < n; ++i) { /* scalar tail for the remaining 0-3 elements */
		r[i] = (1.0f - decay) * g[i] * g[i] + decay * r[i];
		t[i] -= (h? h[i] : h0) / sqrtf(1e-6f + r[i]) * g[i];
	}
}
#else
void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g, float *t, float *r)
{
	int i;
	for (i = 0; i < n; ++i) {
		float lr = h? h[i] : h0;
		r[i] = (1.0f - decay) * g[i] * g[i] + decay * r[i];
		t[i] -= lr / sqrtf(1e-6f + r[i]) * g[i];
	}
}
#endif

float kann_grad_clip(float thres, int n, float *g)
{
	int i;
	double s2 = 0.0;
	for (i = 0; i < n; ++i)
		s2 += g[i] * g[i];
	s2 = sqrt(s2);
	if (s2 > thres) /* rescale the gradient to unit L2 norm */
		for (i = 0, s2 = 1.0 / s2; i < n; ++i)
			g[i] *= (float)s2;
	return (float)s2 / thres; /* NB: s2 is 1/norm here if clipping occurred */
}

/****************************************************************
 *** @@XY: simpler API for network with a single input/output ***
 ****************************************************************/

int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch,
					int max_drop_streak, float frac_val, int n,
					float **_x, float **_y,
					kann_train_cb cb, void *ud)
{
	int i, j, *shuf, n_train, n_val, n_in, n_out, n_var, n_const, drop_streak = 0, min_set = 0;
	float **x, **y, *x1, *y1, *r, min_val_cost = FLT_MAX, *min_x, *min_c;

	n_in = kann_dim_in(ann);
	n_out = kann_dim_out(ann);
	if (n_in < 0 || n_out < 0) return -1;
	n_var = kann_size_var(ann);
	n_const = kann_size_const(ann);
	r = (float*)calloc(n_var, sizeof(float));
	shuf = (int*)malloc(n * sizeof(int));
	x = (float**)malloc(n * sizeof(float*));
	y = (float**)malloc(n * sizeof(float*));
	kann_shuffle(n, shuf);
	for (j = 0; j < n; ++j)
		x[j] = _x[shuf[j]], y[j] = _y[shuf[j]];
	n_val = (int)(n * frac_val);
	n_train = n - n_val;
	min_x = (float*)malloc(n_var * sizeof(float));
	min_c = (float*)malloc(n_const * sizeof(float));

	x1 = (float*)malloc(n_in  * mini_size * sizeof(float));
	y1 = (float*)malloc(n_out * mini_size * sizeof(float));
	kann_feed_bind(ann, KANN_F_IN,    0, &x1);
	kann_feed_bind(ann, KANN_F_TRUTH, 0, &y1);

	for (i = 0; i < max_epoch; ++i) {
		int n_proc = 0, n_train_err = 0, n_val_err = 0, n_train_base = 0, n_val_base = 0;
		double train_cost = 0.0, val_cost = 0.0;
		kann_shuffle(n_train, shuf);
		kann_switch(ann, 1);
		while (n_proc < n_train) {
			int b, c, ms = n_train - n_proc < mini_size? n_train - n_proc : mini_size;
			for (b = 0; b < ms; ++b) {
				memcpy(&x1[b*n_in],  x[shuf[n_proc+b]], n_in  * sizeof(float));
				memcpy(&y1[b*n_out], y[shuf[n_proc+b]], n_out * sizeof(float));
			}
			kann_set_batch_size(ann, ms);
			train_cost += kann_cost(ann, 0, 1) * ms;
			c = kann_class_error(ann, &b);
			n_train_err += c, n_train_base += b;
			kann_RMSprop(n_var, lr, 0, 0.9f, ann->g, ann->x, r);
			n_proc += ms;
		}
		train_cost /= n_train;
		kann_switch(ann, 0);
		n_proc = 0;
		while (n_proc < n_val) {
			int b, c, ms = n_val - n_proc < mini_size? n_val - n_proc : mini_size;
			for (b = 0; b < ms; ++b) {
				memcpy(&x1[b*n_in],  x[n_train+n_proc+b], n_in  * sizeof(float));
				memcpy(&y1[b*n_out], y[n_train+n_proc+b], n_out * sizeof(float));
			}
			kann_set_batch_size(ann, ms);
			val_cost += kann_cost(ann, 0, 0) * ms;
			c = kann_class_error(ann, &b);
			n_val_err += c, n_val_base += b;
			n_proc += ms;
		}
		if (n_val > 0) val_cost /= n_val;
		if (cb) {
			cb(i + 1, train_cost, val_cost, ud);
#if 0
			fprintf(stderr, "epoch: %d; training cost: %g", i+1, train_cost);
			if (n_train_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_train_err / n_train);
			if (n_val > 0) {
				fprintf(stderr, "; validation cost: %g", val_cost);
				if (n_val_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_val_err / n_val);
			}
			fputc('\n', stderr);
#endif
		}
		if (i >= max_drop_streak && n_val > 0) {
			if (val_cost < min_val_cost) {
				min_set = 1;
				memcpy(min_x, ann->x, n_var * sizeof(float));
				memcpy(min_c, ann->c, n_const * sizeof(float));
				drop_streak = 0;
				min_val_cost = (float)val_cost;
			} else if (++drop_streak >= max_drop_streak)
				break;
		}
	}
	if (min_set) {
		memcpy(ann->x, min_x, n_var * sizeof(float));
		memcpy(ann->c, min_c, n_const * sizeof(float));
	}

	free(min_c); free(min_x); free(y1); free(x1); free(y); free(x); free(shuf); free(r);
	return i;
}
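
/* Training sketch (hyperparameters are illustrative; x and y hold n rows of
 * kann_dim_in() and kann_dim_out() floats respectively). This trains with
 * RMSprop at learning rate 0.001 on mini-batches of 64 for up to 50 epochs,
 * holds out 10% of the data for validation and stops after 10 epochs without
 * improvement, keeping the best weights seen:
 *
 *   kann_train_fnn1(ann, 0.001f, 64, 50, 10, 0.1f, n, x, y, 0, 0);
 */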

float kann_cost_fnn1(kann_t *ann, int n, float **x, float **y)
{
	int n_in, n_out, n_proc = 0, mini_size = 64 < n? 64 : n;
	float *x1, *y1;
	double cost = 0.0;

	n_in = kann_dim_in(ann);
	n_out = kann_dim_out(ann);
	if (n <= 0 || n_in < 0 || n_out < 0) return 0.0;

	x1 = (float*)malloc(n_in  * mini_size * sizeof(float));
	y1 = (float*)malloc(n_out * mini_size * sizeof(float));
	kann_feed_bind(ann, KANN_F_IN,    0, &x1);
	kann_feed_bind(ann, KANN_F_TRUTH, 0, &y1);
	kann_switch(ann, 0);
	while (n_proc < n) {
		int b, ms = n - n_proc < mini_size? n - n_proc : mini_size;
		for (b = 0; b < ms; ++b) {
			memcpy(&x1[b*n_in],  x[n_proc+b], n_in  * sizeof(float));
			memcpy(&y1[b*n_out], y[n_proc+b], n_out * sizeof(float));
		}
		kann_set_batch_size(ann, ms);
		cost += kann_cost(ann, 0, 0) * ms;
		n_proc += ms;
	}
	free(y1); free(x1);
	return (float)(cost / n);
}

const float *kann_apply1(kann_t *a, float *x)
{
	int i_out;
	i_out = kann_find(a, KANN_F_OUT, 0);
	if (i_out < 0) return 0;
	kann_set_batch_size(a, 1);
	kann_feed_bind(a, KANN_F_IN, 0, &x);
	kad_eval_at(a->n, a->v, i_out);
	return a->v[i_out]->x;
}
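
/* Single-sample inference sketch (x1 holds one row of kann_dim_in() floats):
 *
 *   const float *y = kann_apply1(ann, x1);
 *
 * On success y points at kann_dim_out() floats inside the graph; the values
 * are overwritten by the next evaluation, so copy them out if needed. */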