1 #include <stdio.h>
2 #include <stdlib.h>
3 #include "tree.h"
4 #include "utils.h"
5 #include "data.h"
6 
/*
 * Recompute which nodes of `t` are leaves from an external list of names.
 *
 * leaf_list is handed to get_paths(); every entry of the resulting list is
 * compared by exact string match (strcmp) against t->name[]. A node whose
 * name appears in the list gets t->leaf[node] = 1; every other node gets 0.
 * The number of matched nodes is printed to stderr.
 *
 * NOTE(review): neither llist nor the array returned by list_to_array() is
 * released here. If both are heap-owned (likely), this leaks on every call;
 * confirm against the list module and free them.
 */
void change_leaves(tree *t, char *leaf_list)
{
    list *llist = get_paths(leaf_list);
    char **leaves = (char **)list_to_array(llist);
    int leaf_count = llist->size;
    int matched = 0;

    for (int node = 0; node < t->n; ++node) {
        int is_leaf = 0;
        /* Stop scanning as soon as the first matching name is found. */
        for (int k = 0; k < leaf_count && !is_leaf; ++k) {
            is_leaf = (strcmp(t->name[node], leaves[k]) == 0);
        }
        t->leaf[node] = is_leaf;
        matched += is_leaf;
    }
    fprintf(stderr, "Found %d leaves.\n", matched);
}
26 
/*
 * Product of x[] over class c and all of its ancestors.
 *
 * Starting at c, multiplies in x[node] and moves to hier->parent[node],
 * stopping once the index becomes negative. Returns 1 when c < 0.
 */
float get_hierarchy_probability(float *x, tree *hier, int c)
{
    float prob = 1;
    for (int node = c; node >= 0; node = hier->parent[node]) {
        prob *= x[node];
    }
    return prob;
}
36 
/*
 * Propagate scores down the hierarchy in place, optionally keeping only leaves.
 *
 * predictions: n scores, one per tree node, modified in place.
 * n:           number of nodes to process.
 * hier:        tree supplying parent[] and leaf[].
 * only_leaves: when non-zero, every non-leaf entry is set to 0 afterwards.
 *
 * Each entry with a parent is multiplied by its parent's entry in index
 * order. Only if parent[k] < k for every node (TODO confirm the tree file
 * guarantees this) is the parent already updated at that point, making each
 * value the product along its whole ancestor chain.
 */
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
    for (int k = 0; k < n; ++k) {
        const int up = hier->parent[k];
        if (up < 0) {
            continue; /* root-level node: nothing to multiply in */
        }
        predictions[k] *= predictions[up];
    }

    if (!only_leaves) {
        return;
    }

    for (int k = 0; k < n; ++k) {
        if (!hier->leaf[k]) {
            predictions[k] = 0;
        }
    }
}
52 
/*
 * Walk the hierarchy from the root group and return the most specific node
 * whose running probability still exceeds `thresh`.
 *
 * predictions: per-node scores; node i is read at predictions[i * stride].
 * hier:        tree providing group_offset[]/group_size[] (members of each
 *              group), child[] (group holding a node's children, negative
 *              when it has none) and parent[].
 * thresh:      the running product of chosen scores must stay above this.
 * stride:      distance between consecutive node scores in predictions.
 *
 * At each level the highest positive score in the current group is chosen.
 * If the running product times that score exceeds thresh, the walk descends
 * into the chosen node's child group, or returns the node if it has none.
 * Otherwise it returns the best node of the root group (group 0), or the
 * parent of the current group's members, i.e. the node chosen one level up.
 */
int hierarchy_top_prediction(float *predictions, tree *hier, float thresh, int stride)
{
    float p = 1;
    int group = 0;
    while (1) {
        const int offset = hier->group_offset[group];
        float max = 0;
        /* Default to the group's first member rather than global index 0, so
         * the fallback choice always belongs to the current group even when
         * no score in it is positive. */
        int max_i = offset;

        for (int i = 0; i < hier->group_size[group]; ++i) {
            const int index = offset + i;
            const float val = predictions[index * stride];
            if (val > max) {
                max_i = index;
                max = val;
            }
        }

        if (p * max > thresh) {
            p = p * max;
            group = hier->child[max_i];
            if (group < 0) {
                return max_i; /* chosen node has no children */
            }
        } else if (group == 0) {
            return max_i;
        } else {
            return hier->parent[offset];
        }
    }
}
84 
read_tree(char * filename)85 tree *read_tree(char *filename)
86 {
87     tree t = {0};
88     FILE *fp = fopen(filename, "r");
89 
90     char *line;
91     int last_parent = -1;
92     int group_size = 0;
93     int groups = 0;
94     int n = 0;
95     while((line=fgetl(fp)) != 0){
96         char* id = (char*)xcalloc(256, sizeof(char));
97         int parent = -1;
98         sscanf(line, "%s %d", id, &parent);
99         t.parent = (int*)xrealloc(t.parent, (n + 1) * sizeof(int));
100         t.parent[n] = parent;
101 
102         t.name = (char**)xrealloc(t.name, (n + 1) * sizeof(char*));
103         t.name[n] = id;
104         if(parent != last_parent){
105             ++groups;
106             t.group_offset = (int*)xrealloc(t.group_offset, groups * sizeof(int));
107             t.group_offset[groups - 1] = n - group_size;
108             t.group_size = (int*)xrealloc(t.group_size, groups * sizeof(int));
109             t.group_size[groups - 1] = group_size;
110             group_size = 0;
111             last_parent = parent;
112         }
113         t.group = (int*)xrealloc(t.group, (n + 1) * sizeof(int));
114         t.group[n] = groups;
115         ++n;
116         ++group_size;
117     }
118     ++groups;
119     t.group_offset = (int*)xrealloc(t.group_offset, groups * sizeof(int));
120     t.group_offset[groups - 1] = n - group_size;
121     t.group_size = (int*)xrealloc(t.group_size, groups * sizeof(int));
122     t.group_size[groups - 1] = group_size;
123     t.n = n;
124     t.groups = groups;
125     t.leaf = (int*)xcalloc(n, sizeof(int));
126     int i;
127     for(i = 0; i < n; ++i) t.leaf[i] = 1;
128     for(i = 0; i < n; ++i) if(t.parent[i] >= 0) t.leaf[t.parent[i]] = 0;
129 
130     fclose(fp);
131     tree* tree_ptr = (tree*)xcalloc(1, sizeof(tree));
132     *tree_ptr = t;
133     //error(0);
134     return tree_ptr;
135 }
136