1 // This is mul/clsfy/clsfy_random_forest.cxx
#include <algorithm>
#include <cmath>
#include <deque>
#include <iostream>
#include <iterator>
#include <limits>
#include <string>
#include "clsfy_random_forest.h"
9 //:
10 // \file
11 // \brief Random forest classifier
12 // \author Martin Roberts
13
14 #ifdef _MSC_VER
15 # include "vcl_msvc_warnings.h"
16 #endif
17 #include <cassert>
18 #include "vsl/vsl_binary_io.h"
19 #include "vsl/vsl_vector_io.h"
20 #include <vnl/io/vnl_io_vector.h>
21 #include <mbl/mbl_cloneable_ptr.h>
22
23
//: Default constructor.
// All members are default-initialised; trees_ therefore starts empty unless
// the header gives it an in-class initialiser (NOTE(review): confirm in header).
clsfy_random_forest::clsfy_random_forest() = default;
25
26 //=======================================================================
27 //: Return the classification of the given probe vector.
classify(const vnl_vector<double> & input) const28 unsigned clsfy_random_forest::classify(const vnl_vector<double> &input) const
29 {
30 #if 1 //Accumulate probabilities (impure final nodes may not return 0 or 1)
31 std::vector<double > classProbs(1,0.0);
32 class_probabilities(classProbs,input);
33 return (classProbs[0]>=0.5) ? 1 : 0;
34 #else // just accumulate number in each class rather than probs
35
36 std::vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIter=trees_.begin();
37 std::vector<mbl_cloneable_ptr<clsfy_classifier_base> >::const_iterator treeIterEnd=trees_.end();
38
39 std::vector<unsigned > classCount(2,0);
40
41 unsigned i=0;
42 while (treeIter != treeIterEnd)
43 {
44 mbl_cloneable_ptr<clsfy_classifier_base> pTree=*treeIter++;
45 unsigned treeClass= pTree->classify(input);
46
47 ++classCount[treeClass];
48 }
49 if (classCount[0] >= classCount[1])
50 return 0;
51 else
52 return 1;
53 #endif // 1
54 }
55
56 //=======================================================================
57 //: Return a probability like value that the input being in each class.
58 // output(i) i<<nClasses, contains the probability that the input is in class i
class_probabilities(std::vector<double> & outputs,vnl_vector<double> const & input) const59 void clsfy_random_forest::class_probabilities(std::vector<double>& outputs,
60 vnl_vector<double>const& input) const
61 {
62 outputs.resize(1);
63
64 auto treeIter=trees_.begin();
65 auto treeIterEnd=trees_.end();
66
67 std::vector<double > classProbs(1,0.0);
68 std::vector<double > meanProbs(1,0.0);
69
70 while (treeIter != treeIterEnd)
71 {
72 const clsfy_classifier_base* pTree=(*treeIter).ptr();
73 pTree->class_probabilities(classProbs, input);
74 meanProbs[0]+=classProbs[0];
75 ++treeIter;
76 }
77 outputs[0]=meanProbs[0]/double (trees_.size());
78 }
79
80
81 //=======================================================================
82 //: This value has properties of a Log likelihood of being in class (binary classifiers only)
83 // class probability = exp(logL) / (1+exp(logL))
log_l(const vnl_vector<double> & input) const84 double clsfy_random_forest::log_l(const vnl_vector<double> &input) const
85 {
86 //Retain logistic function relation to prob
87 //i.e. invert the above relation
88 double epsilon=1.0E-8;
89 std::vector<double > probs(1,0.5);
90 class_probabilities(probs,input);
91 double p=probs[0];
92 double d=(1.0/p)-1.0;
93 double x=1.0;
94 if (d>epsilon)
95 x=-std::log(d);
96 else
97 x=-std::log(epsilon);
98
99 return x;
100 }
101
102 //======================= Out of Bag add-ons ==============================
class_probabilities_oob(std::vector<double> & outputs,const vnl_vector<double> & input,const std::vector<std::vector<unsigned>> & oobIndices,unsigned this_index) const103 void clsfy_random_forest::class_probabilities_oob(std::vector<double> &outputs,
104 const vnl_vector<double> &input,
105 const std::vector<std::vector<unsigned > >& oobIndices,
106 unsigned this_index) const
107 {
108 outputs.resize(1);
109
110 auto treeIter=trees_.begin();
111 auto treeIterEnd=trees_.end();
112
113 std::vector<double > classProbs(1,0.0);
114 std::vector<double > meanProbs(1,0.0);
115 auto oobIndexIter=oobIndices.begin() ;
116 unsigned noob=0;
117 while (treeIter != treeIterEnd)
118 {
119 if (std::find(oobIndexIter->begin(),oobIndexIter->end(),this_index)==oobIndexIter->end())
120 {
121 //Not found this_index, so Out of Bag - accumulate this tree's vote
122 const clsfy_classifier_base* pTree=(*treeIter).ptr();
123
124 pTree->class_probabilities(classProbs, input);
125 meanProbs[0]+=classProbs[0];
126 ++noob;
127 }
128 ++treeIter;
129 ++oobIndexIter;
130 }
131 outputs[0]=meanProbs[0]/double (noob);
132 }
133
134
135 //: Return the classification of the given probe vector using out of bag trees only.
136 // See also class_probabilities_oob
classify_oob(const vnl_vector<double> & input,const std::vector<std::vector<unsigned>> & oobIndices,unsigned this_index) const137 unsigned clsfy_random_forest::classify_oob(const vnl_vector<double> &input,
138 const std::vector<std::vector<unsigned > >& oobIndices,
139 unsigned this_index) const
140 {
141 std::vector<double > classProbs(1,0.0);
142 class_probabilities_oob(classProbs,input,oobIndices,this_index);
143 return (classProbs[0]>=0.5) ? 1 : 0;
144 }
145
146
147 //=======================================================================
148
is_a() const149 std::string clsfy_random_forest::is_a() const
150 {
151 return std::string("clsfy_random_forest");
152 }
153
154 //=======================================================================
155
is_class(std::string const & s) const156 bool clsfy_random_forest::is_class(std::string const& s) const
157 {
158 return s == clsfy_random_forest::is_a() || clsfy_classifier_base::is_class(s);
159 }
160
161 //=======================================================================
162
version_no() const163 short clsfy_random_forest::version_no() const
164 {
165 return 1;
166 }
167
168 //=======================================================================
169
clone() const170 clsfy_classifier_base* clsfy_random_forest::clone() const
171 {
172 return new clsfy_random_forest(*this);
173 }
174
175 //=======================================================================
176
print_summary(std::ostream & os) const177 void clsfy_random_forest::print_summary(std::ostream& os) const
178 {
179 os<<"clsfy_random_forest\t has "<<trees_.size()<<" trees"<<std::endl;
180 }
181
182 //=======================================================================
183
b_write(vsl_b_ostream & bfs) const184 void clsfy_random_forest::b_write(vsl_b_ostream& bfs) const
185 {
186 std::cout<<"clsfy_random_forest::b_write"<<std::endl;
187 vsl_b_write(bfs,version_no());
188 unsigned n=trees_.size();
189 vsl_b_write(bfs,n);
190 for (unsigned i=0; i<n;++i)
191 {
192 trees_[i]->b_write(bfs);
193 }
194 }
195
196 //=======================================================================
197
b_read(vsl_b_istream & bfs)198 void clsfy_random_forest::b_read(vsl_b_istream& bfs)
199 {
200 if (!bfs) return;
201
202 prune();
203 short version;
204 vsl_b_read(bfs,version);
205 switch (version)
206 {
207 case 1:
208 {
209 unsigned n;
210 vsl_b_read(bfs,n);
211 std::cout<<"Am attemptig to read in "<<n<<"\t trees"<<std::endl;
212 trees_.reserve(n);
213 for (unsigned i=0; i<n;++i)
214 {
215 // std::cout<<"reading tree "<<i<<std::endl;
216 mbl_cloneable_ptr< clsfy_classifier_base> tree(new clsfy_binary_tree);
217 trees_.push_back(tree);
218 trees_.back()->b_read(bfs);
219 }
220 break;
221 }
222
223 default:
224 std::cerr << "I/O ERROR: clsfy_random_forest::b_read(vsl_b_istream&)\n"
225 << " Unknown version number "<< version << '\n';
226 bfs.is().clear(std::ios::badbit); // Set an unrecoverable IO error on stream
227 }
228 }
229
~clsfy_random_forest()230 clsfy_random_forest::~clsfy_random_forest()
231 {
232 prune();
233 }
234
prune()235 void clsfy_random_forest::prune()
236 {
237 trees_.clear(); //note mbl wrapper destructor deletes the tree pointer!
238 }
239
240 //=======================================================================
241 //: The dimensionality of input vectors.
n_dims() const242 unsigned clsfy_random_forest::n_dims() const
243 {
244 if (trees_.empty())
245 return 0;
246 else
247 return trees_.front()->n_dims();
248 }
249
//: Append copies of forest2's trees (mbl_cloneable_ptr copies) to this forest.
// Self-append (f += f) is supported and doubles the forest. vector::insert
// must not be given iterators into the vector being modified, so that case
// goes through a temporary copy.
clsfy_random_forest& clsfy_random_forest::operator+=(const clsfy_random_forest& forest2)
{
  if (&forest2 == this)
  {
    const auto selfCopy = trees_;
    trees_.insert(trees_.end(), selfCopy.begin(), selfCopy.end());
    return *this;
  }

  trees_.reserve(trees_.size() + forest2.trees_.size());
  trees_.insert(trees_.end(), forest2.trees_.begin(), forest2.trees_.end());
  return *this;
}
257
258
259 //============ Friend functions for merging stuff ====================
260
261 //: Merge the sub-forests in the input filenames into a single larger one
merge_sub_forests(const std::vector<std::string> & filenames,clsfy_random_forest & large_forest)262 void merge_sub_forests(const std::vector<std::string>& filenames,
263 clsfy_random_forest& large_forest)
264 {
265 auto fileIter=filenames.begin();
266 auto fileIterEnd=filenames.end();
267 while (fileIter != fileIterEnd)
268 {
269 vsl_b_ifstream bfs_in(*fileIter);
270 clsfy_random_forest subForest;
271 vsl_b_read(bfs_in, subForest);
272 bfs_in.close();
273 large_forest.trees_.reserve(large_forest.trees_.size()+subForest.trees_.size());
274 large_forest.trees_.insert(large_forest.trees_.end(),
275 subForest.trees_.begin(),subForest.trees_.end());
276 ++fileIter;
277 }
278 }
279
280 //: Merge the sub-forests pointed to the input vector a single larger one
merge_sub_forests(const std::vector<clsfy_random_forest * > & sub_forests,clsfy_random_forest & large_forest)281 void merge_sub_forests(const std::vector< clsfy_random_forest*>& sub_forests,
282 clsfy_random_forest& large_forest)
283 {
284 auto subForestIter=sub_forests.begin();
285 auto subForestIterEnd=sub_forests.end();
286 while (subForestIter != subForestIterEnd)
287 {
288 const clsfy_random_forest& subForest=**subForestIter;
289 large_forest.trees_.reserve(large_forest.trees_.size()+subForest.trees_.size());
290 large_forest.trees_.insert(large_forest.trees_.end(),
291 subForest.trees_.begin(),subForest.trees_.end());
292 ++subForestIter;
293 }
294 }
295
296 //: Merge the two input forests
operator +(const clsfy_random_forest & forest1,const clsfy_random_forest & forest2)297 clsfy_random_forest operator+(const clsfy_random_forest& forest1,
298 const clsfy_random_forest& forest2)
299 {
300 clsfy_random_forest mergedForest=forest1;
301
302 mergedForest.trees_.reserve(forest1.trees_.size()+forest2.trees_.size());
303 mergedForest.trees_.insert(mergedForest.trees_.end(),
304 forest2.trees_.begin(),forest2.trees_.end());
305 return mergedForest;
306 }
307