1 #include "config.h"
2
3 #include <math.h>
4 #include <float.h>
5 #include <string.h>
6 #include <stdlib.h>
7 #include <assert.h>
8 #include <stdarg.h>
9 #include "kann.h"
10
int kann_verbose = 3; /* global verbosity knob; presumably higher = more logging — TODO confirm levels against callers */
12
13 /******************************************
14 *** @@BASIC: fundamental KANN routines ***
15 ******************************************/
16
/* Flatten model storage: copy every variable node's values into one contiguous
 * array x (with a parallel zeroed gradient array g) and every constant's
 * values into c, then re-point each node's ->x/->g into the shared arrays.
 * The old per-node buffers are freed, so the shared arrays own the storage
 * afterwards.  Traversal order must match kad_ext_sync(). */
static void kad_ext_collate(int n, kad_node_t **a, float **_x, float **_g, float **_c)
{
	int i, j, k, l, n_var;
	float *x, *g, *c;
	n_var = kad_size_var(n, a);
	x = *_x = (float*)realloc(*_x, n_var * sizeof(float));
	g = *_g = (float*)realloc(*_g, n_var * sizeof(float));
	c = *_c = (float*)realloc(*_c, kad_size_const(n, a) * sizeof(float));
	memset(g, 0, n_var * sizeof(float)); /* gradients start at zero */
	for (i = j = k = 0; i < n; ++i) { /* j: running offset into x/g; k: into c */
		kad_node_t *v = a[i];
		if (kad_is_var(v)) {
			l = kad_len(v);
			memcpy(&x[j], v->x, l * sizeof(float));
			free(v->x); /* node's private buffer replaced by a slice of the shared array */
			v->x = &x[j];
			v->g = &g[j];
			j += l;
		} else if (kad_is_const(v)) {
			l = kad_len(v);
			memcpy(&c[k], v->x, l * sizeof(float));
			free(v->x);
			v->x = &c[k];
			k += l;
		}
	}
}
44
/* Re-point every variable/constant node's buffers at the shared arrays x, g
 * and c, using the same traversal order as kad_ext_collate().  No data is
 * copied; only pointers are rewired. */
static void kad_ext_sync(int n, kad_node_t **a, float *x, float *g, float *c)
{
	int i, off_var = 0, off_const = 0;
	for (i = 0; i < n; ++i) {
		kad_node_t *p = a[i];
		if (kad_is_var(p)) {
			p->x = x + off_var;
			p->g = g + off_var;
			off_var += kad_len(p);
		} else if (kad_is_const(p)) {
			p->x = c + off_const;
			off_const += kad_len(p);
		}
	}
}
60
/* Build a model from a cost node plus n_rest extra output roots passed as
 * varargs.  The cost node must be a scalar (n_d == 0); returns 0 otherwise.
 * If the compiled graph is recurrent but contains no pivot node, a kad_avg()
 * pivot is inserted above the cost and the graph is recompiled.  Finally all
 * parameters/constants are collated into the model's flat x/g/c arrays. */
kann_t *kann_new(kad_node_t *cost, int n_rest, ...)
{
	kann_t *a;
	int i, n_roots = 1 + n_rest, has_pivot = 0, has_recur = 0;
	kad_node_t **roots;
	va_list ap;

	if (cost->n_d != 0) return 0; /* cost must be a scalar */

	va_start(ap, n_rest);
	roots = (kad_node_t**)malloc((n_roots + 1) * sizeof(kad_node_t*));
	for (i = 0; i < n_rest; ++i)
		roots[i] = va_arg(ap, kad_node_t*);
	roots[i++] = cost; /* cost is always the last root */
	va_end(ap);

	cost->ext_flag |= KANN_F_COST;
	a = (kann_t*)calloc(1, sizeof(kann_t));
	a->v = kad_compile_array(&a->n, n_roots, roots);

	for (i = 0; i < a->n; ++i) {
		if (a->v[i]->pre) has_recur = 1; /* ->pre links mark recurrent state */
		if (kad_is_pivot(a->v[i])) has_pivot = 1;
	}
	if (has_recur && !has_pivot) { /* an RNN that doesn't have a pivot; then add a pivot on top of cost and recompile */
		cost->ext_flag &= ~KANN_F_COST;
		roots[n_roots-1] = cost = kad_avg(1, &cost), cost->ext_flag |= KANN_F_COST;
		free(a->v);
		a->v = kad_compile_array(&a->n, n_roots, roots);
	}
	kad_ext_collate(a->n, a->v, &a->x, &a->g, &a->c); /* flatten parameter storage */
	free(roots);
	return a;
}
95
/* Duplicate a network with a (possibly different) batch size.  The clone
 * owns fresh x/g/c arrays collated from the cloned graph. */
kann_t *kann_clone(kann_t *a, int batch_size)
{
	kann_t *copy = (kann_t*)calloc(1, sizeof(kann_t));
	copy->n = a->n;
	copy->v = kad_clone(a->n, a->v, batch_size);
	kad_ext_collate(copy->n, copy->v, &copy->x, &copy->g, &copy->c);
	return copy;
}
105
/* Unroll a recurrent network, each pivot expanded len[] steps.  The unrolled
 * copy shares parameter/gradient/constant storage with the original, so it
 * must be freed with kann_delete_unrolled(), not kann_delete(). */
kann_t *kann_unroll_array(kann_t *a, int *len)
{
	kann_t *u = (kann_t*)calloc(1, sizeof(kann_t));
	u->x = a->x; /* shared with a */
	u->g = a->g;
	u->c = a->c;
	u->v = kad_unroll(a->n, a->v, &u->n, len);
	return u;
}
114
/* Varargs front-end to kann_unroll_array(): one unroll length per pivot. */
kann_t *kann_unroll(kann_t *a, ...)
{
	kann_t *u;
	va_list ap;
	int k, n_pivots = kad_n_pivots(a->n, a->v);
	int *len = (int*)calloc(n_pivots, sizeof(int));
	va_start(ap, a);
	for (k = 0; k < n_pivots; ++k)
		len[k] = va_arg(ap, int);
	va_end(ap);
	u = kann_unroll_array(a, len);
	free(len);
	return u;
}
129
/* Free an unrolled model: tear down worker threads (if any) and the graph,
 * but NOT the x/g/c arrays, which belong to the original network. */
void kann_delete_unrolled(kann_t *a)
{
	if (a == 0) return;
	if (a->mt) kann_mt(a, 0, 0);      /* shut down the thread pool first */
	if (a->v) kad_delete(a->n, a->v); /* free the computation graph */
	free(a);
}
136
/* Free a model and everything it owns, including the flat x/g/c arrays. */
void kann_delete(kann_t *a)
{
	if (a == 0) return;
	free(a->x);
	free(a->g);
	free(a->c);
	kann_delete_unrolled(a); /* releases graph, threads and the struct itself */
}
143
/* Flip every train/eval switch node in the graph.  Such nodes are identified
 * by opcode 12 with two children; ->ptr holds the branch selector
 * (TODO confirm opcode against the kad op table). */
static void kann_switch_core(kann_t *a, int is_train)
{
	int i;
	for (i = 0; i < a->n; ++i) {
		kad_node_t *p = a->v[i];
		if (p->op == 12 && p->n_child == 2)
			*(int32_t*)p->ptr = is_train? 1 : 0;
	}
}
151
/* chk_flg: match when no mask is requested, or any masked bit is set */
#define chk_flg(flag, mask) ((mask) == 0 || ((flag) & (mask)))
/* chk_lbl: match when no label is requested, or the label equals the query */
#define chk_lbl(label, query) ((query) == 0 || (label) == (query))
154
kann_find(const kann_t * a,uint32_t ext_flag,int32_t ext_label)155 int kann_find(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
156 {
157 int i, k, r = -1;
158 for (i = k = 0; i < a->n; ++i)
159 if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
160 ++k, r = i;
161 return k == 1? r : k == 0? -1 : -2;
162 }
163
/* Point matching feed nodes at the caller-provided buffers x[], in graph
 * order.  Returns the number of feeds bound (0 if x is NULL). */
int kann_feed_bind(kann_t *a, uint32_t ext_flag, int32_t ext_label, float **x)
{
	int i, n_bound = 0;
	if (x == 0) return 0;
	for (i = 0; i < a->n; ++i) {
		kad_node_t *p = a->v[i];
		if (kad_is_feed(p) && chk_flg(p->ext_flag, ext_flag) && chk_lbl(p->ext_label, ext_label))
			p->x = x[n_bound++];
	}
	return n_bound;
}
173
kann_feed_dim(const kann_t * a,uint32_t ext_flag,int32_t ext_label)174 int kann_feed_dim(const kann_t *a, uint32_t ext_flag, int32_t ext_label)
175 {
176 int i, k, n = 0;
177 for (i = k = 0; i < a->n; ++i)
178 if (kad_is_feed(a->v[i]) && chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label))
179 ++k, n = a->v[i]->n_d > 1? kad_len(a->v[i]) / a->v[i]->d[0] : a->v[i]->n_d == 1? a->v[i]->d[0] : 1;
180 return k == 1? n : k == 0? -1 : -2;
181 }
182
/* Evaluate the (unique) cost node with the given label; optionally run the
 * backward pass as well.  Aborts if no unambiguous cost node exists. */
static float kann_cost_core(kann_t *a, int cost_label, int cal_grad)
{
	float cost;
	int i_cost = kann_find(a, KANN_F_COST, cost_label);
	assert(i_cost >= 0);
	cost = *kad_eval_at(a->n, a->v, i_cost);
	if (cal_grad)
		kad_grad(a->n, a->v, i_cost);
	return cost;
}
193
/* Mark all nodes matching ext_flag/ext_label and forward-evaluate exactly
 * what they need.  Returns the number of nodes marked. */
int kann_eval(kann_t *a, uint32_t ext_flag, int ext_label)
{
	int i, n_marked = 0;
	for (i = 0; i < a->n; ++i) {
		if (chk_flg(a->v[i]->ext_flag, ext_flag) && chk_lbl(a->v[i]->ext_label, ext_label)) {
			a->v[i]->tmp = 1; /* tmp flags the evaluation targets */
			++n_marked;
		}
	}
	kad_eval_marked(a->n, a->v);
	return n_marked;
}
203
/* Prepare a model for step-by-step RNN inference: force batch size 1, seed
 * each recurrent node's buffer from its ->pre counterpart (or zeros), then
 * alias ->pre's storage to the node so state carries over between calls.
 * Pair with kann_rnn_end() to restore normal pointer ownership. */
void kann_rnn_start(kann_t *a)
{
	int i;
	kann_set_batch_size(a, 1);
	for (i = 0; i < a->n; ++i) {
		kad_node_t *p = a->v[i];
		if (p->pre) { /* NB: BE CAREFUL of the interaction between kann_rnn_start() and kann_set_batch_size() */
			kad_node_t *q = p->pre;
			if (q->x) memcpy(p->x, q->x, kad_len(p) * sizeof(float)); /* carry over previous state */
			else memset(p->x, 0, kad_len(p) * sizeof(float));         /* no prior state: start from zeros */
			if (q->n_child > 0) free(q->x); /* internal node owned its buffer; leaf buffers are managed elsewhere */
			q->x = p->x; /* alias: q now shares p's storage */
		}
	}
}
219
/* Undo kann_rnn_start(): re-point leaves at the collated x/g/c arrays and
 * give each internal recurrent source node a fresh zeroed buffer of its own
 * (its previous buffer is owned by the node that aliased it). */
void kann_rnn_end(kann_t *a)
{
	int i;
	kad_ext_sync(a->n, a->v, a->x, a->g, a->c);
	for (i = 0; i < a->n; ++i)
		if (a->v[i]->pre && a->v[i]->pre->n_child > 0)
			a->v[i]->pre->x = (float*)calloc(kad_len(a->v[i]->pre), sizeof(float));
}
228
/* Count argmax classification errors across all cost nodes in the graph.
 * Cost nodes are recognized by opcode: 13 with 2-3 children or 22 with 2
 * children and a scalar result (per the inline comment these are
 * ce_bin/ce_multi — TODO confirm opcodes against the kad op table).  A row
 * is counted in *base only when the truth row sums to 1 with entries in
 * [0,1] and the prediction row lies in [0,1]; an error is a mismatch between
 * the argmax of prediction and truth.  Returns the error count. */
static int kann_class_error_core(const kann_t *ann, int *base)
{
	int i, j, k, m, n, off, n_err = 0;
	for (i = 0, *base = 0; i < ann->n; ++i) {
		kad_node_t *p = ann->v[i];
		if (((p->op == 13 && (p->n_child == 2 || p->n_child == 3)) || (p->op == 22 && p->n_child == 2)) && p->n_d == 0) { /* ce_bin or ce_multi */
			kad_node_t *x = p->child[0], *t = p->child[1]; /* x: prediction; t: truth */
			n = t->d[t->n_d - 1], m = kad_len(t) / n; /* n: classes per row; m: number of rows */
			for (j = off = 0; j < m; ++j, off += n) {
				float t_sum = 0.0f, t_min = 1.0f, t_max = 0.0f, x_max = 0.0f, x_min = 1.0f;
				int x_max_k = -1, t_max_k = -1;
				for (k = 0; k < n; ++k) {
					float xk = x->x[off+k], tk = t->x[off+k];
					t_sum += tk;
					t_min = t_min < tk? t_min : tk;
					x_min = x_min < xk? x_min : xk;
					if (t_max < tk) t_max = tk, t_max_k = k;
					if (x_max < xk) x_max = xk, x_max_k = k;
				}
				if (t_sum - 1.0f == 0 && t_min >= 0.0f && x_min >= 0.0f && x_max <= 1.0f) { /* truth looks like a distribution */
					++(*base);
					n_err += (x_max_k != t_max_k);
				}
			}
		}
	}
	return n_err;
}
257
258 /*************************
259 * @@MT: multi-threading *
260 *************************/
261
262 #ifdef HAVE_PTHREAD
263 #include <pthread.h>
264
struct mtaux_t; /* forward declaration: each worker keeps a back-pointer to the shared state */

typedef struct { /* per-worker data */
	kann_t *a;         /* this worker's clone of the model */
	float cost;        /* cost computed for this worker's shard */
	int action;        /* 0: idle, 1: work available, -1: terminate */
	pthread_t tid;
	struct mtaux_t *g; /* back-pointer to the shared cross-worker state */
} mtaux1_t;

typedef struct mtaux_t { /* cross-worker data */
	int n_threads, max_batch_size;
	int cal_grad, cost_label, eval_out; /* parameters of the current work item */
	volatile int n_idle; /* we will be busy waiting on this, so volatile necessary */
	pthread_mutex_t mtx;
	pthread_cond_t cv;
	mtaux1_t *mt; /* array of n_threads per-worker records; slot 0 is driven by the master thread */
} mtaux_t;
283
/* Worker loop: park on the condition variable until the master assigns an
 * action, run one unit of work (forward-only output eval, or cost plus
 * optional gradients on this worker's model clone), repeat until told to
 * terminate (action == -1). */
static void *mt_worker(void *data) /* pthread worker */
{
	mtaux1_t *mt1 = (mtaux1_t*)data;
	mtaux_t *mt = mt1->g;
	for (;;) {
		int action;
		pthread_mutex_lock(&mt->mtx);
		mt1->action = 0;
		++mt->n_idle; /* tell the master (busy-waiting on n_idle) this worker is parked */
		while (mt1->action == 0) /* wait for mt_kickoff() or mt_destroy() to assign work */
			pthread_cond_wait(&mt->cv, &mt->mtx);
		action = mt1->action;
		pthread_mutex_unlock(&mt->mtx);
		if (action == -1) break; /* -1: terminate (set by mt_destroy()) */

		if (mt->eval_out) kann_eval(mt1->a, KANN_F_OUT, 0); /* forward pass over output nodes only */
		else mt1->cost = kann_cost_core(mt1->a, mt->cost_label, mt->cal_grad);
	}
	pthread_exit(0);
}
304
/* Tear down the thread pool: tell every worker to exit, join them, free all
 * model clones (including slot 0, which runs on the master thread), then the
 * synchronization primitives and the struct itself. */
static void mt_destroy(mtaux_t *mt) /* de-allocate an entire mtaux_t struct */
{
	int i;
	pthread_mutex_lock(&mt->mtx);
	mt->n_idle = 0;
	for (i = 1; i < mt->n_threads; ++i) mt->mt[i].action = -1; /* -1 requests worker exit */
	pthread_cond_broadcast(&mt->cv);
	pthread_mutex_unlock(&mt->mtx);
	for (i = 1; i < mt->n_threads; ++i) pthread_join(mt->mt[i].tid, 0);
	for (i = 0; i < mt->n_threads; ++i) kann_delete(mt->mt[i].a); /* slot 0 has no thread but owns a clone */
	free(mt->mt);
	pthread_cond_destroy(&mt->cv);
	pthread_mutex_destroy(&mt->mtx);
	free(mt);
}
320
/* Enable (n_threads > 1) or disable (n_threads <= 1) multi-threaded
 * training/evaluation.  Each worker gets a clone of the model sized to its
 * share of max_batch_size; clone 0 is driven by the calling thread and the
 * remaining clones each get a pthread running mt_worker(). */
void kann_mt(kann_t *ann, int n_threads, int max_batch_size)
{
	mtaux_t *mt;
	int i, k;

	if (n_threads <= 1) { /* disable threading and release any existing pool */
		if (ann->mt) mt_destroy((mtaux_t*)ann->mt);
		ann->mt = 0;
		return;
	}
	if (n_threads > max_batch_size) n_threads = max_batch_size; /* no point having more threads than samples */
	if (n_threads <= 1) return;

	mt = (mtaux_t*)calloc(1, sizeof(mtaux_t));
	mt->n_threads = n_threads, mt->max_batch_size = max_batch_size;
	pthread_mutex_init(&mt->mtx, 0);
	pthread_cond_init(&mt->cv, 0);
	mt->mt = (mtaux1_t*)calloc(n_threads, sizeof(mtaux1_t));
	for (i = k = 0; i < n_threads; ++i) {
		int size = (max_batch_size - k) / (n_threads - i); /* spread the batch as evenly as possible */
		mt->mt[i].a = kann_clone(ann, size);
		mt->mt[i].g = mt;
		k += size;
	}
	for (i = 1; i < n_threads; ++i) /* slot 0 is the caller itself; no thread needed */
		pthread_create(&mt->mt[i].tid, 0, mt_worker, &mt->mt[i]);
	while (mt->n_idle < n_threads - 1); /* busy waiting until all threads in sync */
	ann->mt = mt;
}
350
/* Dispatch one unit of work to the pool: shard the current mini-batch rows
 * across the clones (each clone's feed nodes are pointed into the master's
 * feed buffers at the right row offset), push the latest parameters into
 * every clone, then wake all workers. */
static void mt_kickoff(kann_t *a, int cost_label, int cal_grad, int eval_out)
{
	mtaux_t *mt = (mtaux_t*)a->mt;
	int i, j, k, B, n_var;

	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
	assert(B <= mt->max_batch_size); /* TODO: can be relaxed */
	n_var = kann_size_var(a);

	pthread_mutex_lock(&mt->mtx);
	mt->cost_label = cost_label, mt->cal_grad = cal_grad, mt->eval_out = eval_out;
	for (i = k = 0; i < mt->n_threads; ++i) {
		int size = (B - k) / (mt->n_threads - i); /* near-even split; same rule as kann_cost() */
		for (j = 0; j < a->n; ++j)
			if (kad_is_feed(a->v[j]))
				mt->mt[i].a->v[j]->x = &a->v[j]->x[k * kad_len(a->v[j]) / a->v[j]->d[0]]; /* row offset into the master feed */
		kad_sync_dim(mt->mt[i].a->n, mt->mt[i].a->v, size); /* TODO: we can point ->x to internal nodes, too */
		k += size;
		memcpy(mt->mt[i].a->x, a->x, n_var * sizeof(float)); /* refresh the clone's parameters */
		mt->mt[i].action = 1;
	}
	mt->n_idle = 0;
	pthread_cond_broadcast(&mt->cv);
	pthread_mutex_unlock(&mt->mtx);
}
376
/* Compute the cost (and, when cal_grad is set, gradients) for the current
 * mini-batch.  Single-threaded models call kann_cost_core() directly.  With
 * a thread pool, the batch is sharded by mt_kickoff(); this thread runs
 * shard 0, busy-waits for the rest, then combines: the total cost is the
 * shard-size-weighted average and a->g accumulates the size-weighted shard
 * gradients.  Recurrent state is copied back row-by-row afterwards. */
float kann_cost(kann_t *a, int cost_label, int cal_grad)
{
	mtaux_t *mt = (mtaux_t*)a->mt;
	int i, j, B, k, n_var;
	float cost;

	if (mt == 0) return kann_cost_core(a, cost_label, cal_grad);
	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
	n_var = kann_size_var(a);

	mt_kickoff(a, cost_label, cal_grad, 0);
	mt->mt[0].cost = kann_cost_core(mt->mt[0].a, cost_label, cal_grad); /* shard 0 runs on this thread */
	while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */

	memset(a->g, 0, n_var * sizeof(float)); /* TODO: check if this is necessary when cal_grad is false */
	for (i = k = 0, cost = 0.0f; i < mt->n_threads; ++i) {
		int size = (B - k) / (mt->n_threads - i); /* must match the split used in mt_kickoff() */
		cost += mt->mt[i].cost * size / B;
		kad_saxpy(n_var, (float)size / B, mt->mt[i].a->g, a->g); /* weighted gradient accumulation */
		k += size;
	}
	for (j = 0; j < a->n; ++j) { /* copy values back at recurrent nodes (needed by textgen; TODO: temporary solution) */
		kad_node_t *p = a->v[j];
		if (p->pre && p->n_d >= 2 && p->d[0] == B) {
			for (i = k = 0; i < mt->n_threads; ++i) {
				kad_node_t *q = mt->mt[i].a->v[j];
				memcpy(&p->x[k], q->x, kad_len(q) * sizeof(float));
				k += kad_len(q);
			}
		}
	}
	return cost;
}
410
/* Forward-evaluate all KANN_F_OUT nodes.  Without a thread pool this is just
 * kann_eval().  With one, shards are evaluated by the workers and each
 * output node's rows are gathered back into the master's buffer, interleaved
 * the same way the forward pass of kad_op_concat() would lay them out. */
int kann_eval_out(kann_t *a)
{
	mtaux_t *mt = (mtaux_t*)a->mt;
	int j, B, n_eval;
	if (mt == 0) return kann_eval(a, KANN_F_OUT, 0);
	B = kad_sync_dim(a->n, a->v, -1); /* get the current batch size */
	mt_kickoff(a, 0, 0, 1);
	n_eval = kann_eval(mt->mt[0].a, KANN_F_OUT, 0); /* shard 0 runs on this thread */
	while (mt->n_idle < mt->n_threads - 1); /* busy waiting until all threads in sync */
	for (j = 0; j < a->n; ++j) { /* copy output values back */
		kad_node_t *p = a->v[j];
		if (p->ext_flag & KANN_F_OUT) {
			int i, t, k, d0 = p->d[0] / B, d1 = 1; /* for RNN, p->d[0] may equal unroll_len * batch_size */
			assert(p->d[0] % B == 0);
			for (i = 1; i < p->n_d; ++i) d1 *= p->d[i]; /* elements per row */
			for (i = 0; i < d0; ++i) {
				for (t = k = 0; t < mt->n_threads; ++t) { /* similar to the forward pass of kad_op_concat() */
					kad_node_t *q = mt->mt[t].a->v[j];
					int size = q->d[0] / d0; /* rows contributed by this shard per step */
					memcpy(&p->x[(i * B + k) * d1], &q->x[i * size * d1], size * d1 * sizeof(float));
					k += size;
				}
			}
		}
	}
	return n_eval;
}
438
kann_class_error(const kann_t * ann,int * base)439 int kann_class_error(const kann_t *ann, int *base)
440 {
441 mtaux_t *mt = (mtaux_t*)ann->mt;
442 int i, n_err = 0, b = 0;
443 if (mt == 0) return kann_class_error_core(ann, base);
444 for (i = 0; i < mt->n_threads; ++i) {
445 n_err += kann_class_error_core(mt->mt[i].a, &b);
446 *base += b;
447 }
448 return n_err;
449 }
450
/* Set train/eval mode on the model; with a thread pool, the mode must be
 * propagated to every worker clone. */
void kann_switch(kann_t *ann, int is_train)
{
	mtaux_t *mt = (mtaux_t*)ann->mt;
	if (mt) {
		int i;
		for (i = 0; i < mt->n_threads; ++i)
			kann_switch_core(mt->mt[i].a, is_train);
	} else kann_switch_core(ann, is_train);
}
462 #else
/* Single-threaded fallbacks used when KANN is built without pthread support
 * (HAVE_PTHREAD undefined): each forwards straight to the core routine. */
void kann_mt(kann_t *ann, int n_threads, int max_batch_size) { (void)ann, (void)n_threads, (void)max_batch_size; } /* threading unavailable: no-op */
float kann_cost(kann_t *a, int cost_label, int cal_grad) { return kann_cost_core(a, cost_label, cal_grad); }
int kann_eval_out(kann_t *a) { return kann_eval(a, KANN_F_OUT, 0); }
int kann_class_error(const kann_t *a, int *base) { return kann_class_error_core(a, base); }
void kann_switch(kann_t *ann, int is_train) { kann_switch_core(ann, is_train); } /* fix: removed 'return' of a void expression (C constraint violation) */
468 #endif
469
470 /***********************
471 *** @@IO: model I/O ***
472 ***********************/
473
474 #define KANN_MAGIC "KAN\1"
475
/* Serialize a model to fp: 4-byte magic, graph topology, variable values,
 * then constants.  Batch size is forced to 1 first so the saved feed
 * dimensions are canonical.  NOTE(review): fwrite() results are unchecked,
 * so short writes go undetected by callers. */
void kann_save_fp(FILE *fp, kann_t *ann)
{
	kann_set_batch_size(ann, 1);
	fwrite(KANN_MAGIC, 1, 4, fp);
	kad_save(fp, ann->n, ann->v);
	fwrite(ann->x, sizeof(float), kann_size_var(ann), fp);
	fwrite(ann->c, sizeof(float), kann_size_const(ann), fp);
}
484
/* Save a model to file fn, or to stdout when fn is NULL or "-".
 * Fixes: a failed fopen() previously fell through to kann_save_fp(NULL)
 * and fclose(NULL); stdout is now flushed rather than closed, so the
 * caller's stream remains usable. */
void kann_save(const char *fn, kann_t *ann)
{
	FILE *fp;
	fp = fn && strcmp(fn, "-")? fopen(fn, "wb") : stdout;
	if (fp == 0) return; /* fopen() failed; nothing to write */
	kann_save_fp(fp, ann);
	if (fp == stdout) fflush(fp);
	else fclose(fp);
}
492
kann_load_fp(FILE * fp)493 kann_t *kann_load_fp(FILE *fp)
494 {
495 char magic[4];
496 kann_t *ann;
497 int n_var, n_const;
498
499 (void) !fread(magic, 1, 4, fp);
500 if (strncmp(magic, KANN_MAGIC, 4) != 0) {
501 return 0;
502 }
503 ann = (kann_t*)calloc(1, sizeof(kann_t));
504 ann->v = kad_load(fp, &ann->n);
505 n_var = kad_size_var(ann->n, ann->v);
506 n_const = kad_size_const(ann->n, ann->v);
507 ann->x = (float*)malloc(n_var * sizeof(float));
508 ann->g = (float*)calloc(n_var, sizeof(float));
509 ann->c = (float*)malloc(n_const * sizeof(float));
510 (void) !fread(ann->x, sizeof(float), n_var, fp);
511 (void) !fread(ann->c, sizeof(float), n_const, fp);
512 kad_ext_sync(ann->n, ann->v, ann->x, ann->g, ann->c);
513 return ann;
514 }
515
kann_load(const char * fn)516 kann_t *kann_load(const char *fn)
517 {
518 FILE *fp;
519 kann_t *ann;
520 fp = fn && strcmp(fn, "-")? fopen(fn, "rb") : stdin;
521 ann = kann_load_fp(fp);
522 fclose(fp);
523 return ann;
524 }
525
526 /**********************************************
527 *** @@LAYER: layers and model generation ***
528 **********************************************/
529
530 /********** General but more complex APIs **********/
531
/* Create a leaf node (variable or constant) with dimensions d[0..n_d-1].
 * Scalars/vectors (n_d <= 1) are filled with x0_01; higher-rank leaves get
 * normal-distributed values scaled by 1/sqrt(fan-in).  When offset/par are
 * given, the leaf is fetched from (or recorded into) the shared array par at
 * *offset — this is the weight-sharing mechanism used by the *2 layer
 * builders; *offset is advanced either way. */
kad_node_t *kann_new_leaf_array(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, int32_t d[KAD_MAX_DIM])
{
	int i, len, off = offset && par? *offset : -1; /* -1: sharing disabled */
	kad_node_t *p;

	if (off >= 0 && par[off]) return par[(*offset)++]; /* leaf already exists: share it */
	p = (kad_node_t*)calloc(1, sizeof(kad_node_t));
	p->n_d = n_d, p->flag = flag;
	memcpy(p->d, d, n_d * sizeof(int32_t));
	len = kad_len(p);
	p->x = (float*)calloc(len, sizeof(float));
	if (p->n_d <= 1) { /* scalar or vector (e.g. a bias): constant fill */
		for (i = 0; i < len; ++i)
			p->x[i] = x0_01;
	} else { /* matrix or higher: normal init scaled by 1/sqrt(fan-in) */
		double sdev_inv;
		sdev_inv = 1.0 / sqrt((double)len / p->d[0]);
		for (i = 0; i < len; ++i)
			p->x[i] = (float)(kad_drand_normal(0) * sdev_inv);
	}
	if (off >= 0) par[off] = p, ++(*offset); /* record the new leaf for future sharing */
	return p;
}
555
kann_new_leaf2(int * offset,kad_node_p * par,uint8_t flag,float x0_01,int n_d,...)556 kad_node_t *kann_new_leaf2(int *offset, kad_node_p *par, uint8_t flag, float x0_01, int n_d, ...)
557 {
558 int32_t i, d[KAD_MAX_DIM];
559 va_list ap;
560 va_start(ap, n_d); for (i = 0; i < n_d; ++i) d[i] = va_arg(ap, int); va_end(ap);
561 return kann_new_leaf_array(offset, par, flag, x0_01, n_d, d);
562 }
563
/* Fully-connected layer: out = in * W^T + b, with optional weight sharing
 * through offset/par. */
kad_node_t *kann_layer_dense2(int *offset, kad_node_p *par, kad_node_t *in, int n1)
{
	kad_node_t *w, *b;
	int n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in); /* per-sample input width */
	w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);   /* weight matrix (n1 x n0) */
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);       /* bias vector */
	return kad_add(kad_cmul(in, w), b);
}
573
/* Dropout layer: a switch node selects the identity branch at inference
 * time and the dropout branch at training time; r is kept as a constant. */
kad_node_t *kann_layer_dropout2(int *offset, kad_node_p *par, kad_node_t *t, float r)
{
	kad_node_t *branch[2], *rate;
	rate = kann_new_leaf2(offset, par, KAD_CONST, r, 0); /* scalar dropout rate */
	branch[0] = t;                    /* inference: pass-through */
	branch[1] = kad_dropout(t, rate); /* training: apply dropout */
	return kad_switch(2, branch);
}
581
/* Layer normalization: standardize in, then apply learned gain (init 1)
 * and shift (init 0) per feature. */
kad_node_t *kann_layer_layernorm2(int *offset, kad_node_t **par, kad_node_t *in)
{
	kad_node_t *gain, *shift;
	int width = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in); /* per-sample width */
	gain = kann_new_leaf2(offset, par, KAD_VAR, 1.0f, 1, width);
	shift = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, width);
	return kad_add(kad_mul(kad_stdnorm(in), gain), shift);
}
591
/* Matrix product, optionally followed by layer normalization. */
static inline kad_node_t *cmul_norm2(int *offset, kad_node_t **par, kad_node_t *x, kad_node_t *w, int use_norm)
{
	kad_node_t *prod = kad_cmul(x, w);
	if (use_norm) prod = kann_layer_layernorm2(offset, par, prod);
	return prod;
}
596
/* Vanilla RNN step: h_t = tanh(x_t*W + h_{t-1}*U + b).  h0 supplies the
 * initial hidden state and out->pre links the output back to it so the graph
 * can be unrolled.  KANN_RNN_NORM inserts layer normalization after each
 * matrix product; in may be NULL for an input-less recurrence. */
kad_node_t *kann_layer_rnn2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag)
{
	int n0, n1 = h0->d[h0->n_d-1], use_norm = !!(rnn_flag & KANN_RNN_NORM);
	kad_node_t *t, *w, *u, *b, *out;

	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1); /* recurrent weights */
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, h0, u, use_norm);
	if (in) {
		n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in);
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0); /* input weights */
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	out = kad_tanh(kad_add(t, b));
	out->pre = h0; /* recurrence link used during unrolling */
	return out;
}
614
/* One GRU step with initial state h0; out->pre links the new state back to
 * h0 for unrolling.  KANN_RNN_NORM applies layer normalization after each
 * matrix product; in may be NULL for an input-less recurrence. */
kad_node_t *kann_layer_gru2(int *offset, kad_node_t **par, kad_node_t *in, kad_node_t *h0, int rnn_flag)
{
	int n0 = 0, n1 = h0->d[h0->n_d-1], use_norm = !!(rnn_flag & KANN_RNN_NORM);
	kad_node_t *t, *r, *z, *w, *u, *b, *s, *out;

	if (in) n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in); /* per-sample input width */
	/* z = sigm(x_t * W_z + h_{t-1} * U_z + b_z) */
	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, h0, u, use_norm);
	if (in) {
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	z = kad_sigm(kad_add(t, b)); /* update gate */
	/* r = sigm(x_t * W_r + h_{t-1} * U_r + b_r) */
	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, h0, u, use_norm);
	if (in) {
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	r = kad_sigm(kad_add(t, b)); /* reset gate */
	/* s = tanh(x_t * W_s + (h_{t-1} # r) * U_s + b_s) */
	u = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n1);
	b = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 1, n1);
	t = cmul_norm2(offset, par, kad_mul(r, h0), u, use_norm);
	if (in) {
		w = kann_new_leaf2(offset, par, KAD_VAR, 0.0f, 2, n1, n0);
		t = kad_add(cmul_norm2(offset, par, in, w, use_norm), t);
	}
	s = kad_tanh(kad_add(t, b)); /* candidate state */
	/* h_t = z # h_{t-1} + (1 - z) # s */
	out = kad_add(kad_mul(kad_1minus(z), s), kad_mul(z, h0));
	out->pre = h0; /* recurrence link used during unrolling */
	return out;
}
653
654 /********** APIs without offset & par **********/
655
kann_new_leaf(uint8_t flag,float x0_01,int n_d,...)656 kad_node_t *kann_new_leaf(uint8_t flag, float x0_01, int n_d, ...)
657 {
658 int32_t i, d[KAD_MAX_DIM];
659 va_list ap;
660 va_start(ap, n_d); for (i = 0; i < n_d; ++i) d[i] = va_arg(ap, int); va_end(ap);
661 return kann_new_leaf_array(0, 0, flag, x0_01, n_d, d);
662 }
663
/* Convenience constructors for commonly shaped leaves. */
kad_node_t *kann_new_scalar(uint8_t flag, float x) { return kann_new_leaf(flag, x, 0); }
kad_node_t *kann_new_weight(int n_row, int n_col) { return kann_new_leaf(KAD_VAR, 0.0f, 2, n_row, n_col); }
kad_node_t *kann_new_vec(int n, float x) { return kann_new_leaf(KAD_VAR, x, 1, n); }
kad_node_t *kann_new_bias(int n) { return kann_new_vec(n, 0.0f); }
kad_node_t *kann_new_weight_conv2d(int n_out, int n_in, int k_row, int k_col) { return kann_new_leaf(KAD_VAR, 0.0f, 4, n_out, n_in, k_row, k_col); }
kad_node_t *kann_new_weight_conv1d(int n_out, int n_in, int kernel_len) { return kann_new_leaf(KAD_VAR, 0.0f, 3, n_out, n_in, kernel_len); }
670
kann_layer_input(int n1)671 kad_node_t *kann_layer_input(int n1)
672 {
673 kad_node_t *t;
674 t = kad_feed(2, 1, n1);
675 t->ext_flag |= KANN_F_IN;
676 return t;
677 }
678
/* Wrappers over the *2 builders without weight sharing (offset/par = 0). */
kad_node_t *kann_layer_dense(kad_node_t *in, int n1) { return kann_layer_dense2(0, 0, in, n1); }
kad_node_t *kann_layer_dropout(kad_node_t *t, float r) { return kann_layer_dropout2(0, 0, t, r); }
kad_node_t *kann_layer_layernorm(kad_node_t *in) { return kann_layer_layernorm2(0, 0, in); }
682
/* Vanilla RNN layer with n1 hidden units; the initial state is trainable
 * when KANN_RNN_VAR_H0 is set and constant zeros otherwise. */
kad_node_t *kann_layer_rnn(kad_node_t *in, int n1, int rnn_flag)
{
	kad_node_t *state;
	state = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	state->x = (float*)calloc(n1, sizeof(float));
	return kann_layer_rnn2(0, 0, in, state, rnn_flag);
}
690
/* GRU layer with n1 hidden units; the initial state is trainable when
 * KANN_RNN_VAR_H0 is set and constant zeros otherwise. */
kad_node_t *kann_layer_gru(kad_node_t *in, int n1, int rnn_flag)
{
	kad_node_t *state;
	state = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	state->x = (float*)calloc(n1, sizeof(float));
	return kann_layer_gru2(0, 0, in, state, rnn_flag);
}
698
/* Matrix product followed by layer normalization; signature-compatible with
 * kad_cmul so kann_layer_lstm() can select either via a function pointer. */
static kad_node_t *kann_cmul_norm(kad_node_t *x, kad_node_t *w)
{
	kad_node_t *prod = kad_cmul(x, w);
	return kann_layer_layernorm(prod);
}
703
/* One LSTM layer with n1 hidden units.  h0/c0 (initial hidden and cell
 * state) are trainable when KANN_RNN_VAR_H0 is set, constant zeros
 * otherwise.  KANN_RNN_NORM swaps plain kad_cmul for a layer-normalized
 * variant on every gate input and normalizes the cell state (Ba et al.,
 * 2016).  ->pre links on out and c set up the recurrences for unrolling. */
kad_node_t *kann_layer_lstm(kad_node_t *in, int n1, int rnn_flag)
{
	int n0;
	kad_node_t *i, *f, *o, *g, *w, *u, *b, *h0, *c0, *c, *out;
	kad_node_t *(*cmul)(kad_node_t*, kad_node_t*) = (rnn_flag & KANN_RNN_NORM)? kann_cmul_norm : kad_cmul;

	n0 = in->n_d >= 2? kad_len(in) / in->d[0] : kad_len(in); /* per-sample input width */
	h0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	h0->x = (float*)calloc(n1, sizeof(float));
	c0 = (rnn_flag & KANN_RNN_VAR_H0)? kad_var(0, 0, 2, 1, n1) : kad_const(0, 2, 1, n1);
	c0->x = (float*)calloc(n1, sizeof(float));

	/* i = sigm(x_t * W_i + h_{t-1} * U_i + b_i) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_bias(n1);
	i = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* f = sigm(x_t * W_f + h_{t-1} * U_f + b_f) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_vec(n1, 1.0f); /* see Jozefowicz et al on using a large bias */
	f = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* o = sigm(x_t * W_o + h_{t-1} * U_o + b_o) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_bias(n1);
	o = kad_sigm(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* g = tanh(x_t * W_g + h_{t-1} * U_g + b_g) */
	w = kann_new_weight(n1, n0);
	u = kann_new_weight(n1, n1);
	b = kann_new_bias(n1);
	g = kad_tanh(kad_add(kad_add(cmul(in, w), cmul(h0, u)), b));
	/* c_t = c_{t-1} # f + g # i */
	c = kad_add(kad_mul(f, c0), kad_mul(g, i)); /* can't be kad_mul(c0, f)!!! */
	c->pre = c0; /* cell-state recurrence */
	/* h_t = tanh(c_t) # o */
	if (rnn_flag & KANN_RNN_NORM) c = kann_layer_layernorm(c); /* see Ba et al (2016) about how to apply layer normalization to LSTM */
	out = kad_mul(kad_tanh(c), o);
	out->pre = h0; /* hidden-state recurrence */
	return out;
}
745
/* 2-D convolution layer: filter bank shaped (n_flt, in-channels, k_rows,
 * k_cols); the input channel count is taken from in->d[1]. */
kad_node_t *kann_layer_conv2d(kad_node_t *in, int n_flt, int k_rows, int k_cols, int stride_r, int stride_c, int pad_r, int pad_c)
{
	kad_node_t *flt = kann_new_weight_conv2d(n_flt, in->d[1], k_rows, k_cols);
	return kad_conv2d(in, flt, stride_r, stride_c, pad_r, pad_c);
}
752
/* 1-D convolution layer: filter bank shaped (n_flt, in-channels, k_size);
 * the input channel count is taken from in->d[1]. */
kad_node_t *kann_layer_conv1d(kad_node_t *in, int n_flt, int k_size, int stride, int pad)
{
	kad_node_t *flt = kann_new_weight_conv1d(n_flt, in->d[1], k_size);
	return kad_conv1d(in, flt, stride, pad);
}
759
/* Append a dense output layer of width n_out plus an activation and matching
 * cost node.  A truth feed is created here (flagged KANN_F_TRUTH); the
 * activated output is flagged KANN_F_OUT and the cost node KANN_F_COST.
 * Pairings: MSE = linear output; CEB = sigmoid + binary cross-entropy;
 * CEB_NEG = tanh + cross-entropy on (-1,1) labels; CEM = softmax +
 * multi-class cross-entropy. */
kad_node_t *kann_layer_cost(kad_node_t *t, int n_out, int cost_type)
{
	kad_node_t *cost = 0, *truth = 0;
	assert(cost_type == KANN_C_CEB || cost_type == KANN_C_CEM || cost_type == KANN_C_CEB_NEG || cost_type == KANN_C_MSE);
	t = kann_layer_dense(t, n_out);
	truth = kad_feed(2, 1, n_out), truth->ext_flag |= KANN_F_TRUTH;

	if (cost_type == KANN_C_MSE) { /* mean squared error on the linear output */
		cost = kad_mse(t, truth);
	} else if (cost_type == KANN_C_CEB) { /* sigmoid + binary cross-entropy */
		t = kad_sigm(t);
		cost = kad_ce_bin(t, truth);
	} else if (cost_type == KANN_C_CEB_NEG) { /* tanh + cross-entropy for (-1,1) labels */
		t = kad_tanh(t);
		cost = kad_ce_bin_neg(t, truth);
	} else if (cost_type == KANN_C_CEM) { /* softmax + multi-class cross-entropy */
		t = kad_softmax(t);
		cost = kad_ce_multi(t, truth);
	}
	else {
		assert (0); /* unreachable: cost_type validated above */
	}

	t->ext_flag |= KANN_F_OUT;
	cost->ext_flag |= KANN_F_COST;

	return cost;
}
788
/* Fill s[0..n-1] with 0..n-1 and shuffle it (Fisher-Yates, driven by the
 * library's kad_drand() generator). */
void kann_shuffle(int n, int *s)
{
	int i;
	for (i = 0; i < n; ++i) s[i] = i;
	for (i = n; i > 0; --i) {
		int j = (int)(i * kad_drand(0)); /* uniform index in [0, i) */
		int tmp = s[j];
		s[j] = s[i-1];
		s[i-1] = tmp;
	}
}
798
799 /***************************
800 *** @@MIN: minimization ***
801 ***************************/
802
803 #ifdef __SSE__
804 #include <xmmintrin.h>
805
/* RMSprop update (SSE path): r = decay*r + (1-decay)*g^2 and
 * t -= lr/sqrt(1e-6+r) * g, with per-weight learning rate h[] when h != 0,
 * otherwise the scalar h0.  Processes four floats per iteration; the scalar
 * loop handles the n % 4 tail.  NOTE(review): _mm_rsqrt_ps is a ~12-bit
 * approximation, so SIMD lanes differ slightly from the scalar tail. */
void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g, float *t, float *r)
{
	int i, n4 = n>>2<<2; /* largest multiple of 4 not exceeding n */
	__m128 vh, vg, vr, vt, vd, vd1, tmp, vtiny;
	vh = _mm_set1_ps(h0);
	vd = _mm_set1_ps(decay);
	vd1 = _mm_set1_ps(1.0f - decay);
	vtiny = _mm_set1_ps(1e-6f); /* guards against division by zero */
	for (i = 0; i < n4; i += 4) {
		vt = _mm_loadu_ps(&t[i]);
		vr = _mm_loadu_ps(&r[i]);
		vg = _mm_loadu_ps(&g[i]);
		if (h) vh = _mm_loadu_ps(&h[i]); /* per-weight learning rates */
		vr = _mm_add_ps(_mm_mul_ps(vd1, _mm_mul_ps(vg, vg)), _mm_mul_ps(vd, vr));
		_mm_storeu_ps(&r[i], vr);
		tmp = _mm_sub_ps(vt, _mm_mul_ps(_mm_mul_ps(vh, _mm_rsqrt_ps(_mm_add_ps(vtiny, vr))), vg));
		_mm_storeu_ps(&t[i], tmp);
	}
	for (; i < n; ++i) { /* scalar tail */
		r[i] = (1. - decay) * g[i] * g[i] + decay * r[i];
		t[i] -= (h? h[i] : h0) / sqrtf(1e-6f + r[i]) * g[i];
	}
}
829 #else
kann_RMSprop(int n,float h0,const float * h,float decay,const float * g,float * t,float * r)830 void kann_RMSprop(int n, float h0, const float *h, float decay, const float *g, float *t, float *r)
831 {
832 int i;
833 for (i = 0; i < n; ++i) {
834 float lr = h? h[i] : h0;
835 r[i] = (1.0f - decay) * g[i] * g[i] + decay * r[i];
836 t[i] -= lr / sqrtf(1e-6f + r[i]) * g[i];
837 }
838 }
839 #endif
840
kann_grad_clip(float thres,int n,float * g)841 float kann_grad_clip(float thres, int n, float *g)
842 {
843 int i;
844 double s2 = 0.0;
845 for (i = 0; i < n; ++i)
846 s2 += g[i] * g[i];
847 s2 = sqrt(s2);
848 if (s2 > thres)
849 for (i = 0, s2 = 1.0 / s2; i < n; ++i)
850 g[i] *= (float)s2;
851 return (float)s2 / thres;
852 }
853
854 /****************************************************************
855 *** @@XY: simpler API for network with a single input/output ***
856 ****************************************************************/
857
/* Train a feed-forward network with a single input and a single output node.
 *
 * ann:             the network; must expose one KANN_F_IN and one KANN_F_TRUTH feed
 * lr:              RMSprop learning rate
 * mini_size:       mini-batch size
 * max_epoch:       maximum number of epochs
 * max_drop_streak: stop after this many consecutive epochs without a new
 *                  best validation cost (early stopping)
 * frac_val:        fraction of the n samples held out for validation
 * n:               total number of samples
 * _x, _y:          arrays of n pointers to input / truth vectors
 * cb, ud:          optional per-epoch callback and its user data
 *
 * Returns the number of epochs actually run, or -1 if the network has no
 * usable input/output dimension.
 */
int kann_train_fnn1(kann_t *ann, float lr, int mini_size, int max_epoch,
		int max_drop_streak, float frac_val, int n,
		float **_x, float **_y,
		kann_train_cb cb, void *ud)
{
	int i, j, *shuf, n_train, n_val, n_in, n_out, n_var, n_const, drop_streak = 0, min_set = 0;
	float **x, **y, *x1, *y1, *r, min_val_cost = FLT_MAX, *min_x, *min_c;

	n_in = kann_dim_in(ann);
	n_out = kann_dim_out(ann);
	if (n_in < 0 || n_out < 0) return -1;
	n_var = kann_size_var(ann);
	n_const = kann_size_const(ann);
	r = (float*)calloc(n_var, sizeof(float)); /* RMSprop squared-gradient accumulator, zero-initialized */
	shuf = (int*)malloc(n * sizeof(int));
	x = (float**)malloc(n * sizeof(float*));
	y = (float**)malloc(n * sizeof(float*));
	/* shuffle once so the train/validation split below is random */
	kann_shuffle(n, shuf);
	for (j = 0; j < n; ++j)
		x[j] = _x[shuf[j]], y[j] = _y[shuf[j]];
	n_val = (int)(n * frac_val);
	n_train = n - n_val; /* first n_train entries train; the rest validate */
	/* snapshots of the best-so-far variables and constants for early stopping */
	min_x = (float*)malloc(n_var * sizeof(float));
	min_c = (float*)malloc(n_const * sizeof(float));

	/* staging buffers for one mini-batch, bound once to the network feeds */
	x1 = (float*)malloc(n_in * mini_size * sizeof(float));
	y1 = (float*)malloc(n_out * mini_size * sizeof(float));
	kann_feed_bind(ann, KANN_F_IN, 0, &x1);
	kann_feed_bind(ann, KANN_F_TRUTH, 0, &y1);

	for (i = 0; i < max_epoch; ++i) {
		int n_proc = 0, n_train_err = 0, n_val_err = 0, n_train_base = 0, n_val_base = 0;
		double train_cost = 0.0, val_cost = 0.0;
		/* reshuffle the training portion each epoch */
		kann_shuffle(n_train, shuf);
		kann_switch(ann, 1); /* training mode (e.g. dropout active) */
		while (n_proc < n_train) {
			/* ms: actual batch size; the last batch may be short */
			int b, c, ms = n_train - n_proc < mini_size? n_train - n_proc : mini_size;
			for (b = 0; b < ms; ++b) {
				memcpy(&x1[b*n_in], x[shuf[n_proc+b]], n_in * sizeof(float));
				memcpy(&y1[b*n_out], y[shuf[n_proc+b]], n_out * sizeof(float));
			}
			kann_set_batch_size(ann, ms);
			/* forward + backward pass; gradients land in ann->g */
			train_cost += kann_cost(ann, 0, 1) * ms;
			c = kann_class_error(ann, &b);
			n_train_err += c, n_train_base += b;
			kann_RMSprop(n_var, lr, 0, 0.9f, ann->g, ann->x, r);
			n_proc += ms;
		}
		train_cost /= n_train;
		kann_switch(ann, 0); /* evaluation mode for validation */
		n_proc = 0;
		while (n_proc < n_val) {
			int b, c, ms = n_val - n_proc < mini_size? n_val - n_proc : mini_size;
			for (b = 0; b < ms; ++b) {
				/* validation samples sit after the first n_train entries */
				memcpy(&x1[b*n_in], x[n_train+n_proc+b], n_in * sizeof(float));
				memcpy(&y1[b*n_out], y[n_train+n_proc+b], n_out * sizeof(float));
			}
			kann_set_batch_size(ann, ms);
			val_cost += kann_cost(ann, 0, 0) * ms; /* no gradient computation */
			c = kann_class_error(ann, &b);
			n_val_err += c, n_val_base += b;
			n_proc += ms;
		}
		if (n_val > 0) val_cost /= n_val;
		if (cb) {
			cb(i + 1, train_cost, val_cost, ud);
#if 0
			fprintf(stderr, "epoch: %d; training cost: %g", i+1, train_cost);
			if (n_train_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_train_err / n_train);
			if (n_val > 0) {
				fprintf(stderr, "; validation cost: %g", val_cost);
				if (n_val_base) fprintf(stderr, " (class error: %.2f%%)", 100.0f * n_val_err / n_val);
			}
			fputc('\n', stderr);
#endif
		}
		/* early stopping: only starts tracking after max_drop_streak epochs */
		if (i >= max_drop_streak && n_val > 0) {
			if (val_cost < min_val_cost) {
				min_set = 1;
				memcpy(min_x, ann->x, n_var * sizeof(float));
				memcpy(min_c, ann->c, n_const * sizeof(float));
				drop_streak = 0;
				min_val_cost = (float)val_cost;
			} else if (++drop_streak >= max_drop_streak)
				break;
		}
	}
	/* restore the best validation snapshot, if one was recorded */
	if (min_set) {
		memcpy(ann->x, min_x, n_var * sizeof(float));
		memcpy(ann->c, min_c, n_const * sizeof(float));
	}

	free(min_c); free(min_x); free(y1); free(x1); free(y); free(x); free(shuf); free(r);
	return i;
}
953
float kann_cost_fnn1(kann_t *ann, int n, float **x, float **y)
{
	/* Compute the mean cost of ann over n (input, truth) sample pairs
	 * without updating any parameters. Returns 0.0 when n <= 0 or the
	 * network lacks a usable input/output dimension.
	 */
	int d_in, d_out, done = 0, bs;
	float *in_buf, *out_buf;
	double total = 0.0;

	d_in = kann_dim_in(ann);
	d_out = kann_dim_out(ann);
	if (n <= 0 || d_in < 0 || d_out < 0) return 0.0;
	bs = n < 64? n : 64; /* process in mini-batches of at most 64 */

	in_buf  = (float*)malloc(d_in  * bs * sizeof(float));
	out_buf = (float*)malloc(d_out * bs * sizeof(float));
	kann_feed_bind(ann, KANN_F_IN, 0, &in_buf);
	kann_feed_bind(ann, KANN_F_TRUTH, 0, &out_buf);
	kann_switch(ann, 0); /* evaluation mode */
	while (done < n) {
		int k, cur = n - done < bs? n - done : bs;
		for (k = 0; k < cur; ++k) {
			memcpy(&in_buf[k*d_in], x[done+k], d_in * sizeof(float));
			memcpy(&out_buf[k*d_out], y[done+k], d_out * sizeof(float));
		}
		kann_set_batch_size(ann, cur);
		total += kann_cost(ann, 0, 0) * cur; /* no gradient computation */
		done += cur;
	}
	free(out_buf); free(in_buf);
	return (float)(total / n);
}
982
const float *kann_apply1(kann_t *a, float *x)
{
	/* Run a single input vector x through network a and return a pointer
	 * to the output node's values (owned by the network), or NULL if the
	 * network has no KANN_F_OUT node.
	 */
	int out;
	if ((out = kann_find(a, KANN_F_OUT, 0)) < 0) return 0;
	kann_set_batch_size(a, 1);
	kann_feed_bind(a, KANN_F_IN, 0, &x);
	kad_eval_at(a->n, a->v, out);
	return a->v[out]->x;
}
993