/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RANDOM_FOREST_HXX
#define VIGRA_RANDOM_FOREST_HXX

#include <iostream>
#include <algorithm>
#include <map>
#include <set>
#include <list>
#include <numeric>
#include <cfloat>   // FLT_MAX, used by the online prediction code below
#include <cassert>
#include "mathutil.hxx"
#include "array_vector.hxx"
#include "sized_int.hxx"
#include "matrix.hxx"
#include "metaprogramming.hxx"
#include "random.hxx"
#include "functorexpression.hxx"
#include "random_forest/rf_common.hxx"
#include "random_forest/rf_nodeproxy.hxx"
#include "random_forest/rf_split.hxx"
#include "random_forest/rf_decisionTree.hxx"
#include "random_forest/rf_visitors.hxx"
#include "random_forest/rf_region.hxx"
#include "sampling.hxx"
#include "random_forest/rf_preprocessing.hxx"
#include "random_forest/rf_online_prediction_set.hxx"
#include "random_forest/rf_earlystopping.hxx"
#include "random_forest/rf_ridge_split.hxx"

namespace vigra
{

/** \addtogroup MachineLearning Machine Learning

    This module provides classification algorithms that map
    features to labels or label probabilities.
    Look at the \ref vigra::RandomForest class (for implementation version 2) or the
    \ref vigra::rf3::random_forest() factory function (for implementation version 3)
    for an overview of the functionality as well as use cases.
**/

namespace detail
{



/* \brief sampling option factory function
 */
inline SamplerOptions make_sampler_opt(RandomForestOptions & RF_opt)
{
    SamplerOptions return_opt;
    return_opt.withReplacement(RF_opt.sample_with_replacement_);
    return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
    return return_opt;
}
} // namespace detail

/** \brief Random forest version 2 (see also \ref vigra::rf3::RandomForest for version 3)
 *
 * \ingroup MachineLearning
 *
 * \tparam <LabelType = double> Type used for predicted labels.
 * \tparam <PreprocessorTag = ClassificationTag> Class used to preprocess
 *          the input while learning and predicting. Currently available:
 *          ClassificationTag and RegressionTag. It is recommended to use
 *          Splitfunctor::Preprocessor_t when using custom split functors,
 *          as they may need the data to be in a different format.
 * \sa Preprocessor
 *
 * Simple usage for classification (regression is not yet supported):
 * see RandomForest::learn() as well as RandomForestOptions() for additional
 * options.
 *
 * \code
 * using namespace vigra;
 * using namespace rf;
 * typedef xxx feature_t; // replace xxx with whichever type
 * typedef yyy label_t;   // likewise
 *
 * // allocate the training data
 * MultiArrayView<2, feature_t> f = get_training_features();
 * MultiArrayView<2, label_t> l = get_training_labels();
 *
 * RandomForest<label_t> rf;
 *
 * // construct visitor to calculate out-of-bag error
 * visitors::OOB_Error oob_v;
 *
 * // perform training
 * rf.learn(f, l, visitors::create_visitor(oob_v));
 *
 * std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
 *
 * // get features for new data to be used for prediction
 * MultiArrayView<2, feature_t> pf = get_features();
 *
 * // allocate space for the response (pf.shape(0) is the number of samples)
 * MultiArray<2, label_t> prediction(Shape2(pf.shape(0), 1));
 * MultiArray<2, double> prob(Shape2(pf.shape(0), rf.class_count()));
 *
 * // perform prediction on new data
 * rf.predictLabels(pf, prediction);
 * rf.predictProbabilities(pf, prob);
 *
 * \endcode
 *
 * Additional information such as variable importance measures is accessed
 * via visitors defined in rf::visitors.
 * Have a look at rf::split for other splitting methods.
 *
 */
template <class LabelType = double, class PreprocessorTag = ClassificationTag>
class RandomForest
{

  public:
    // public typedefs
    typedef RandomForestOptions         Options_t;
    typedef detail::DecisionTree        DecisionTree_t;
    typedef ProblemSpec<LabelType>      ProblemSpec_t;
    typedef GiniSplit                   Default_Split_t;
    typedef EarlyStoppStd               Default_Stop_t;
    typedef rf::visitors::StopVisiting  Default_Visitor_t;
    typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
                                        StackEntry_t;
    typedef LabelType                   LabelT;

    // problem independent data
    Options_t options_;
    // problem dependent data members - only set if
    // a copy constructor, some sort of import
    // function, or the learn function is called
    ArrayVector<DecisionTree_t> trees_;
    ProblemSpec_t ext_param_;
    /*mutable ArrayVector<int> tree_indices_;*/
    rf::visitors::OnlineLearnVisitor online_visitor_;


    void reset()
    {
        ext_param_.clear();
        trees_.clear();
    }

  public:

    /** \name Constructors
     * Note: no copy constructor is specified because this class
     * does not manage any raw pointers.

     * @{
     */

    /**\brief default constructor
     *
     * \param options   general options to the Random Forest. Must be of type
     *                  Options_t
     * \param ext_param problem specific values that can be supplied
     *                  additionally (class weights, labels etc.)
     * \sa RandomForestOptions, ProblemSpec
     *
     */
    RandomForest(Options_t const & options = Options_t(),
                 ProblemSpec_t const & ext_param = ProblemSpec_t())
    :
        options_(options),
        ext_param_(ext_param)/*,
        tree_indices_(options.tree_count_,0)*/
    {
        /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
            tree_indices_[ii] = ii;*/
    }

    /**\brief Create RF from external source
     * \param treeCount       number of trees to add.
     * \param topology_begin  iterator to a container where the topology_ data
     *                        of the trees is stored. The iterator must support
     *                        at least treeCount forward iterations
     *                        (i.e. topology_end - topology_begin >= treeCount).
     * \param parameter_begin iterator to a container where the parameters_ data
     *                        of the trees is stored. The iterator must support
     *                        at least treeCount forward iterations.
     * \param problem_spec    extrinsic parameters that specify the problem, e.g.
     *                        ClassCount, featureCount etc.
     * \param options         (optional) options used to train the original
     *                        Random Forest. This parameter is not used anywhere
     *                        during prediction and thus is optional.
     *
     */
    template<class TopologyIterator, class ParameterIterator>
    RandomForest(int treeCount,
                 TopologyIterator topology_begin,
                 ParameterIterator parameter_begin,
                 ProblemSpec_t const & problem_spec,
                 Options_t const & options = Options_t())
    :
        options_(options),
        trees_(treeCount, DecisionTree_t(problem_spec)),
        ext_param_(problem_spec)
    {
        /* TODO: This constructor may be replaced by a constructor using
         * NodeProxy iterators to encapsulate the underlying data type.
         */
        for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
        {
            trees_[k].topology_ = *topology_begin;
            trees_[k].parameters_ = *parameter_begin;
        }
    }
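
    /* A usage sketch for this constructor. The container types below are
     * assumptions; any forward-iterable sequence over the per-tree
     * topology_ and parameters_ arrays works, and "spec" would be filled
     * from the stored model:
     *
     *   std::vector<ArrayVector<Int32> >  topologies;  // one entry per tree
     *   std::vector<ArrayVector<double> > parameters;  // one entry per tree
     *   ProblemSpec<int> spec;  // class/feature counts etc. of the stored model
     *   RandomForest<int> rf(int(topologies.size()),
     *                        topologies.begin(), parameters.begin(), spec);
     */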

    /** @} */


    /** \name Data Access
     * data access interface - direct use of the member variables is deprecated
     *
     * @{
     */

    /**\brief return external parameters for viewing
     * \return ProblemSpec_t
     */
    ProblemSpec_t const & ext_param() const
    {
        vigra_precondition(ext_param_.used() == true,
            "RandomForest::ext_param(): "
            "Random forest has not been trained yet.");
        return ext_param_;
    }

    /**\brief set external parameters
     *
     * \param in external parameters to be set
     *
     * Set external parameters explicitly.
     * If the random forest has not been trained, the preprocessor will
     * either ignore values set this way or will throw an exception
     * if values specified manually do not match the values calculated
     * during the preparation step.
     */
    void set_ext_param(ProblemSpec_t const & in)
    {
        ignore_argument(in);
        vigra_precondition(ext_param_.used() == false,
            "RandomForest::set_ext_param(): "
            "Random forest has been trained! Call reset() "
            "before specifying new extrinsic parameters.");
    }

    /**\brief access random forest options
     *
     * \return random forest options
     */
    Options_t & set_options()
    {
        return options_;
    }


    /**\brief access const random forest options
     *
     * \return const Option_t
     */
    Options_t const & options() const
    {
        return options_;
    }
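
    /* Sketch: because set_options() returns a mutable reference, option
     * setters can be chained in place before training. The setter names
     * below follow RandomForestOptions; treat the exact names as
     * assumptions to be checked against rf_common.hxx:
     *
     *   RandomForest<int> rf;
     *   rf.set_options().tree_count(100).sample_with_replacement(true);
     */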

    /**\brief access const trees
     */
    DecisionTree_t const & tree(int index) const
    {
        return trees_[index];
    }

    /**\brief access trees
     */
    DecisionTree_t & tree(int index)
    {
        return trees_[index];
    }

    /**\brief return number of features used while
     *        training.
     */
    int feature_count() const
    {
        return ext_param_.column_count_;
    }


    /**\brief return number of features used while
     *        training.
     *
     * Deprecated. Use feature_count() instead.
     */
    int column_count() const
    {
        return ext_param_.column_count_;
    }

    /**\brief return number of classes used while
     *        training.
     */
    int class_count() const
    {
        return ext_param_.class_count_;
    }

    /**\brief return number of trees
     */
    int tree_count() const
    {
        return options_.tree_count_;
    }

    /** @} */

    /**\name Learning
     * The following functions differ in the degree of customization
     * allowed.
     *
     * @{
     */

    /**\brief learn on data with custom config and random number generator
     *
     * \param features  a N x M matrix containing N samples with M
     *                  features
     * \param response  a N x D matrix containing the corresponding
     *                  response. Current split functors assume D to
     *                  be 1 and ignore any additional columns.
     *                  This is not enforced to allow future support
     *                  for uncertain labels, label independent strata etc.
     *                  The Preprocessor specified during construction
     *                  should be able to handle the features and labels.
     *                  see also: SplitFunctor, Preprocessing
     *
     * \param visitor   visitor which is to be applied after each split,
     *                  tree and at the end. Use rf_default() for the
     *                  default value (no visitors).
     *                  see also: rf::visitors
     * \param split     split functor to be used to calculate each split.
     *                  Use rf_default() for the default value (GiniSplit).
     *                  see also: rf::split
     * \param stop
     *                  predicate used to decide when to stop splitting a
     *                  node. Use rf_default() for the default value
     *                  (EarlyStoppStd).
     * \param random    RandomNumberGenerator to be used. Use
     *                  rf_default() for the default value (RandomMT19937).
     *
     *
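     * A usage sketch with every customization slot spelled out (f and l as
     * in the class example above): the defaults are selected explicitly
     * via rf_default(), while the random number generator is supplied by
     * the caller, e.g. for reproducible training.
     *
     * \code
     * RandomForest<label_t> rf;
     * RandomNumberGenerator<> rnd(RandomSeed);
     * rf.learn(f, l,
     *          rf_default(),   // visitor slot: no visitors
     *          rf_default(),   // split slot: defaults to GiniSplit
     *          rf_default(),   // stop slot: defaults to EarlyStoppStd
     *          rnd);           // caller-supplied random number generator
     * \endcode
     *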
     */
    template <class U, class C1,
              class U2,class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t,
              class Random_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & response,
                Visitor_t visitor,
                Split_t split,
                Stop_t stop,
                Random_t const & random);

    template <class U, class C1,
              class U2,class C2,
              class Split_t,
              class Stop_t,
              class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & response,
                Visitor_t visitor,
                Split_t split,
                Stop_t stop)

    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        learn(  features,
                response,
                visitor,
                split,
                stop,
                rnd);
    }

    template <class U, class C1, class U2,class C2, class Visitor_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels,
                Visitor_t visitor)
    {
        learn(  features,
                labels,
                visitor,
                rf_default(),
                rf_default());
    }

    template <class U, class C1, class U2,class C2,
              class Visitor_t, class Split_t>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels,
                Visitor_t visitor,
                Split_t split)
    {
        learn(  features,
                labels,
                visitor,
                split,
                rf_default());
    }

    /**\brief learn on data with default configuration
     *
     * \param features  a N x M matrix containing N samples with M
     *                  features
     * \param labels    a N x D matrix containing the corresponding
     *                  N labels. Current split functors assume D to
     *                  be 1 and ignore any additional columns.
     *                  This is not enforced to allow future support
     *                  for uncertain labels.
     *
     * Learning is done with:
     *
     * \sa rf::split, EarlyStoppStd
     *
     * - a randomly seeded random number generator
     * - the default Gini split functor as described by Breiman
     * - the standard early stopping criterion (EarlyStoppStd)
     */
    template <class U, class C1, class U2,class C2>
    void learn( MultiArrayView<2, U, C1> const & features,
                MultiArrayView<2, U2,C2> const & labels)
    {
        learn(  features,
                labels,
                rf_default(),
                rf_default(),
                rf_default());
    }


    /**\brief incrementally learn new samples (online learning)
     *
     * Updates an existing forest with the samples whose row index is
     * >= new_start_index. Requires that online learning was enabled
     * in the options when the forest was constructed.
     */
    template<class U,class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void onlineLearn(   MultiArrayView<2,U,C1> const & features,
                        MultiArrayView<2,U2,C2> const & response,
                        int new_start_index,
                        Visitor_t visitor_,
                        Split_t split_,
                        Stop_t stop_,
                        Random_t & random,
                        bool adjust_thresholds=false);

    template <class U, class C1, class U2,class C2>
    void onlineLearn(   MultiArrayView<2, U, C1> const & features,
                        MultiArrayView<2, U2,C2> const & labels,
                        int new_start_index,
                        bool adjust_thresholds=false)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        onlineLearn(features,
                    labels,
                    new_start_index,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd,
                    adjust_thresholds);
    }
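
    /* Sketch of the intended online-learning workflow. The option setter
     * name prepare_online_learning() is an assumption (the corresponding
     * member is options_.prepare_online_learning_; check rf_common.hxx):
     *
     *   RandomForest<int> rf(RandomForestOptions().prepare_online_learning(true));
     *   rf.learn(features, labels);                       // initial batch
     *   // ... append rows to features/labels ...
     *   rf.onlineLearn(features, labels, old_row_count);  // update with the new rows
     */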

    /**\brief discard a single tree and learn it anew on a fresh
     *        bootstrap sample
     */
    template<class U,class C1,
             class U2, class C2,
             class Split_t,
             class Stop_t,
             class Visitor_t,
             class Random_t>
    void reLearnTree(MultiArrayView<2,U,C1> const & features,
                     MultiArrayView<2,U2,C2> const & response,
                     int treeId,
                     Visitor_t visitor_,
                     Split_t split_,
                     Stop_t stop_,
                     Random_t & random);

    template<class U, class C1, class U2, class C2>
    void reLearnTree(MultiArrayView<2, U, C1> const & features,
                     MultiArrayView<2, U2, C2> const & labels,
                     int treeId)
    {
        RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
        reLearnTree(features,
                    labels,
                    treeId,
                    rf_default(),
                    rf_default(),
                    rf_default(),
                    rnd);
    }
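
    /* Sketch: replace a single tree of a forest that was trained with
     * online learning enabled (reLearnTree checks this precondition):
     *
     *   rf.reLearnTree(features, labels, 3);  // retrain tree number 3
     */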

    /** @} */



    /**\name Prediction
     *
     * @{
     */

    /** \brief predict a label given a feature.
     *
     * \param features: a 1 by featureCount matrix containing
     *        the data point to be predicted (this only works in
     *        classification setting)
     * \param stop: early stopping criterion
     * \return double value representing the class. You can use the
     *         predictLabels() function together with the
     *         rf.ext_param().class_type_ attribute
     *         to get back the same type used during learning.
     */
    template <class U, class C, class Stop>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;

    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C>const & features)
    {
        return predictLabel(features, rf_default());
    }
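
    /* Sketch: predict a single sample by passing a one-row view
     * (pf and label_t as in the class example above):
     *
     *   label_t single = rf.predictLabel(rowVector(pf, 3));
     */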
    /** \brief predict a label with features and class priors
     *
     * \param features: same as above.
     * \param prior:    iterator to the prior weighting of classes
     * \return same as above.
     */
    template <class U, class C>
    LabelType predictLabel(MultiArrayView<2, U, C> const & features,
                           ArrayVectorView<double> prior) const;

    /** \brief predict multiple labels with given features
     *
     * \param features: a n by featureCount matrix containing
     *        data points to be predicted (this only works in
     *        classification setting)
     * \param labels: a n by 1 matrix passed by reference to store
     *        output.
     *
     * If the input contains a NaN value, a precondition exception is thrown.
     */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
        {
            vigra_precondition(!detail::contains_nan(rowVector(features, k)),
                "RandomForest::predictLabels(): NaN in feature matrix.");
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
        }
    }

    /** \brief predict multiple labels with given features
     *
     * \param features: a n by featureCount matrix containing
     *        data points to be predicted (this only works in
     *        classification setting)
     * \param labels: a n by 1 matrix passed by reference to store
     *        output.
     * \param nanLabel: label to be returned for rows of the input that
     *        contain a NaN value.
     */
    template <class U, class C1, class T, class C2>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels,
                       LabelType nanLabel) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
        {
            if(detail::contains_nan(rowVector(features, k)))
                labels(k,0) = nanLabel;
            else
                labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
        }
    }
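
    /* Sketch: map rows that contain NaN to a sentinel label instead of
     * throwing. The sentinel value -1 is an arbitrary choice and assumes
     * it is not a valid class label:
     *
     *   rf.predictLabels(pf, prediction, label_t(-1));
     */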

    /** \brief predict multiple labels with given features
     *
     * \param features: a n by featureCount matrix containing
     *        data points to be predicted (this only works in
     *        classification setting)
     * \param labels: a n by 1 matrix passed by reference to store
     *        output.
     * \param stop: an early stopping criterion.
     */
    template <class U, class C1, class T, class C2, class Stop>
    void predictLabels(MultiArrayView<2, U, C1>const & features,
                       MultiArrayView<2, T, C2> & labels,
                       Stop & stop) const
    {
        vigra_precondition(features.shape(0) == labels.shape(0),
            "RandomForest::predictLabels(): Label array has wrong size.");
        for(int k=0; k<features.shape(0); ++k)
            labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
    }
    /** \brief predict the class probabilities for multiple labels
     *
     *  \param features same as above
     *  \param prob a n x class_count_ matrix passed by reference to
     *         store the class probabilities
     *  \param stop early stopping criterion
     *  \sa EarlyStopping

        When a row of the feature array contains a NaN, the corresponding instance
        cannot belong to any of the classes. The corresponding row in the probability
        array will therefore contain all zeros.
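
        A minimal sketch (pf and prob allocated as in the class example
        above); rows of pf that contain NaN come back as all-zero rows
        of prob, so they are easy to detect afterwards:

        \code
        rf.predictProbabilities(pf, prob);
        \endcode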
     */
    template <class U, class C1, class T, class C2, class Stop>
    void predictProbabilities(MultiArrayView<2, U, C1>const & features,
                              MultiArrayView<2, T, C2> & prob,
                              Stop & stop) const;

    /**\brief predict class probabilities for an online prediction set
     */
    template <class T1,class T2, class C>
    void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                              MultiArrayView<2, T2, C> & prob);

    /** \brief predict the class probabilities for multiple labels
     *
     *  \param features same as above
     *  \param prob a n x class_count_ matrix passed by reference to
     *         store the class probabilities
     */
    template <class U, class C1, class T, class C2>
    void predictProbabilities(MultiArrayView<2, U, C1>const & features,
                              MultiArrayView<2, T, C2> & prob) const
    {
        predictProbabilities(features, prob, rf_default());
    }

    /**\brief raw tree votes: per-class vote weights summed over all trees
     *        and divided by the tree count (rows are not normalized to
     *        sum to one).
     */
    template <class U, class C1, class T, class C2>
    void predictRaw(MultiArrayView<2, U, C1>const & features,
                    MultiArrayView<2, T, C2> & prob) const;


    /** @} */

};


template <class LabelType, class PreprocessorTag>
template<class U,class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
                                                           MultiArrayView<2,U2,C2> const & response,
                                                           int new_start_index,
                                                           Visitor_t visitor_,
                                                           Split_t split_,
                                                           Stop_t stop_,
                                                           Random_t & random,
                                                           bool adjust_thresholds)
{
    online_visitor_.activate();
    online_visitor_.adjust_thresholds=adjust_thresholds;

    using namespace rf;
    //typedefs
    typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
    typedef UniformIntRandomFunctor<Random_t>
                                                    RandFunctor_t;
    // Default values and initialization:
    // Value_Chooser selects its second argument as the value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // care about it - just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type>
            IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    vigra_precondition(options_.prepare_online_learning_,
        "onlineLearn(): online learning must be enabled on RandomForest construction");

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    ext_param_.class_count_=0;
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    // Create Poisson samples from the new part of the data.
    PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));

    //TODO: visitors for online learning
    //visitor.visit_at_beginning(*this, preprocessor);

    // The main loop: update each tree in turn.
    for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
    {
        online_visitor_.tree_id=ii;
        poisson_sampler.sample();
        std::map<int,int> leaf_parents;
        leaf_parents.clear();
        // Get all the leaf nodes for this sample.
        for(int s=0;s<poisson_sampler.numOfSamples();++s)
        {
            int sample=poisson_sampler[s];
            online_visitor_.current_label=preprocessor.response()(sample,0);
            online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
            int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);


            // Add the sample to the index list of that leaf.
            online_visitor_.add_to_index_list(ii,leaf,sample);
            //TODO: Class count?
            // Store the parent, unless the leaf already predicts the
            // sample's class with probability 1.
            if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
            {
                leaf_parents[leaf]=online_visitor_.last_node_id;
            }
        }


        std::map<int,int>::iterator leaf_iterator;
        for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
        {
            int leaf=leaf_iterator->first;
            int parent=leaf_iterator->second;
            int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
            ArrayVector<Int32> indices;
            indices.clear();
            indices.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
            StackEntry_t stack_entry(indices.begin(),
                                     indices.end(),
                                     ext_param_.class_count_);


            if(parent!=-1)
            {
                if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
                {
                    stack_entry.leftParent=parent;
                }
                else
                {
                    vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,
                                 "last_node_id seems to be wrong");
                    stack_entry.rightParent=parent;
                }
            }
            //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
            trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
            // Now, the last node has been moved onto the old leaf.
            online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
            // Now it should be classified correctly!
        }

        /*visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                poisson_sampler,
                                stack_entry,
                                ii);*/
    }

    //visitor.visit_at_end(*this, preprocessor);
    online_visitor_.deactivate();
}

template<class LabelType, class PreprocessorTag>
template<class U,class C1,
         class U2, class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
                                                           MultiArrayView<2,U2,C2> const & response,
                                                           int treeId,
                                                           Visitor_t visitor_,
                                                           Split_t split_,
                                                           Stop_t stop_,
                                                           Random_t & random)
{
    using namespace rf;


    typedef UniformIntRandomFunctor<Random_t>
                                                    RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    ext_param_.class_count_=0;
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    // Default values and initialization:
    // Value_Chooser selects its second argument as the value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // care about it - just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode
                <rf::visitors::OnlineLearnVisitor,
                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    vigra_precondition(options_.prepare_online_learning_,
        "reLearnTree(): relearning a tree only makes sense if online learning is enabled");
    online_visitor_.activate();

    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);

    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);

    /**\todo    replace this class. It uses function pointers and,
     *          in my estimation, slows the code down.
     *          (Comment from Nathan: this is copied from Rahul, so "me" = Rahul.)
     */
    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                      .sampleSize(ext_param().actual_msample_),
                               &random);
    // Initialize the first region/node/stack entry.
    sampler
        .sample();

    StackEntry_t
        first_stack_entry(  sampler.sampledIndices().begin(),
                            sampler.sampledIndices().end(),
                            ext_param_.class_count_);
    first_stack_entry
        .set_oob_range(     sampler.oobIndices().begin(),
                            sampler.oobIndices().end());
    online_visitor_.reset_tree(treeId);
    online_visitor_.tree_id=treeId;
    trees_[treeId].reset();
    trees_[treeId]
        .learn( preprocessor.features(),
                preprocessor.response(),
                first_stack_entry,
                split,
                stop,
                visitor,
                randint);
    visitor
        .visit_after_tree(  *this,
                            preprocessor,
                            sampler,
                            first_stack_entry,
                            treeId);

    online_visitor_.deactivate();
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1,
          class U2,class C2,
          class Split_t,
          class Stop_t,
          class Visitor_t,
          class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
                     learn( MultiArrayView<2, U, C1> const & features,
                            MultiArrayView<2, U2,C2> const & response,
                            Visitor_t visitor_,
                            Split_t split_,
                            Stop_t stop_,
                            Random_t const & random)
{
    using namespace rf;
    //this->reset();
    //typedefs
    typedef UniformIntRandomFunctor<Random_t>
                                                    RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    vigra_precondition(features.shape(0) == response.shape(0),
        "RandomForest::learn(): shape mismatch between features and response.");

    // Default values and initialization:
    // Value_Chooser selects its second argument as the value if the first
    // argument is of type RF_DEFAULT (thanks to template magic - don't
    // care about it - just smile and wave).

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split
            = RF_CHOOSER(Split_t)::choose(split_, default_split);
    rf::visitors::StopVisiting stopvisiting;
    typedef rf::visitors::detail::VisitorNode<
                rf::visitors::OnlineLearnVisitor,
                typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    if(options_.prepare_online_learning_)
        online_visitor_.activate();
    else
        online_visitor_.deactivate();


    // Make an stl compatible random functor.
    RandFunctor_t           randint     ( random);


    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    // Initialize the trees.
    trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));

    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                      .sampleSize(ext_param().actual_msample_),
                               &random);

    visitor.visit_at_beginning(*this, preprocessor);
    // The main loop: learn each tree on its own bootstrap sample.

    for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
    {
        // Initialize the first region/node/stack entry.
        sampler
            .sample();
        StackEntry_t
            first_stack_entry(  sampler.sampledIndices().begin(),
                                sampler.sampledIndices().end(),
                                ext_param_.class_count_);
        first_stack_entry
            .set_oob_range(     sampler.oobIndices().begin(),
                                sampler.oobIndices().end());
        trees_[ii]
            .learn(             preprocessor.features(),
                                preprocessor.response(),
                                first_stack_entry,
                                split,
                                stop,
                                visitor,
                                randint);
        visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                sampler,
                                first_stack_entry,
                                ii);
    }

    visitor.visit_at_end(*this, preprocessor);
    // Only for online learning?
    online_visitor_.deactivate();
}




template <class LabelType, class Tag>
template <class U, class C, class Stop>
LabelType RandomForest<LabelType, Tag>
    ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
{
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel():"
        " Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    MultiArray<2, double> probabilities(Shape2(1, ext_param_.class_count_), 0.0);
    LabelType d;
    predictProbabilities(features, probabilities, stop);
    ext_param_.to_classlabel(argMax(probabilities), d);
    return d;
}


// Same as above, but with priors for each label.
template <class LabelType, class PreprocessorTag>
template <class U, class C>
LabelType RandomForest<LabelType, PreprocessorTag>
    ::predictLabel( MultiArrayView<2, U, C> const & features,
                    ArrayVectorView<double> priors) const
{
    using namespace functor;
    vigra_precondition(columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictLabel(): Too few columns in feature matrix.");
    vigra_precondition(rowCount(features) == 1,
        "RandomForest::predictLabel():"
        " Feature matrix must have a single row.");
    Matrix<double> prob(1,ext_param_.class_count_);
    predictProbabilities(features, prob);
    std::transform( prob.begin(), prob.end(),
                    priors.begin(), prob.begin(),
                    Arg1()*Arg2());
    LabelType d;
    ext_param_.to_classlabel(argMax(prob), d);
    return d;
}

template<class LabelType,class PreprocessorTag>
template <class T1,class T2, class C>
void RandomForest<LabelType,PreprocessorTag>
    ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
                           MultiArrayView<2, T2, C> & prob)
{
    // Features are n x p.
    // prob is n x NumOfLabel: the probability for each sample in each class.

    vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");
    // The feature matrix must have at least as many columns as were used
    // during training; any extra columns are ignored.
    vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == static_cast<MultiArrayIndex>(ext_param_.class_count_),
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");
    prob.init(0.0);
    // Store the total weights.
    std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
    // Go through all trees.
    int set_id=-1;
    for(int k=0; k<options_.tree_count_; ++k)
    {
        set_id=(set_id+1) % predictionSet.indices[0].size();
        typedef std::set<SampleRange<T1> > my_set;
        typedef typename my_set::iterator set_it;
        //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
        // Build a stack with all the ranges we have.
        std::vector<std::pair<int,set_it> > stack;
        stack.clear();
        for(set_it i=predictionSet.ranges[set_id].begin();
            i!=predictionSet.ranges[set_id].end();++i)
            stack.push_back(std::pair<int,set_it>(2,i));
        // Get the weights predicted by a single tree.
        int num_decisions=0;
        while(!stack.empty())
        {
            set_it range=stack.back().second;
            int index=stack.back().first;
            stack.pop_back();
            ++num_decisions;

            if(trees_[k].isLeafNode(trees_[k].topology_[index]))
            {
                ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
                                                                            trees_[k].parameters_,
                                                                            index).prob_begin();
                for(int i=range->start;i!=range->end;++i)
                {
                    // Update the vote count.
                    for(int l=0; l<ext_param_.class_count_; ++l)
                    {
                        prob(predictionSet.indices[set_id][i], l) += static_cast<T2>(weights[l]);
                        // Add every weight to totalWeights.
                        totalWeights[predictionSet.indices[set_id][i]] += static_cast<T1>(weights[l]);
                    }
                }
            }

            else
            {
                if(trees_[k].topology_[index]!=i_ThresholdNode)
                {
                    throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
                }
                Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
                if(range->min_boundaries[node.column()]>=node.threshold())
                {
                    // Everything goes to the right child.
                    stack.push_back(std::pair<int,set_it>(node.child(1),range));
                    continue;
                }
                if(range->max_boundaries[node.column()]<node.threshold())
                {
                    // Everything goes to the left child.
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                    continue;
                }
                // We have to split the range at this node.
                SampleRange<T1> new_range=*range;
                new_range.min_boundaries[node.column()]=FLT_MAX;
                range->max_boundaries[node.column()]=-FLT_MAX;
                new_range.start=new_range.end=range->end;
                int i=range->start;
                while(i!=range->end)
                {
                    // Decide for range->indices[i].
                    if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
                    {
                        new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
                                                                         predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        --range->end;
                        --new_range.start;
                        std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);

                    }
                    else
                    {
                        range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
                                                                      predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        ++i;
                    }
                }
                // The old range ...
                if(range->start==range->end)
                {
                    predictionSet.ranges[set_id].erase(range);
                }
                else
                {
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                }
                // ... and the new one.
                if(new_range.start!=new_range.end)
                {
                    std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
                    stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
                }
            }
        }
        predictionSet.cumulativePredTime[k]=num_decisions;
    }
    for(unsigned int i=0;i<totalWeights.size();++i)
    {
        double test=0.0;
        // Normalise the votes in each row by the total vote count (totalWeights).
        for(int l=0; l<ext_param_.class_count_; ++l)
        {
            test+=prob(i,l);
            prob(i, l) /= totalWeights[i];
        }
        assert(test==totalWeights[i]);
        assert(totalWeights[i]>0.0);
    }
}

template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2, class Stop_t>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(MultiArrayView<2, U, C1>const & features,
                           MultiArrayView<2, T, C2> & prob,
                           Stop_t & stop_) const
{
    // Features are n x p.
    // prob is n x NumOfLabel: the probability for each sample in each class.

    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");

    // The feature matrix must have at least as many columns as were used
    // during training; any extra columns are ignored.
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == static_cast<MultiArrayIndex>(ext_param_.class_count_),
        "RandomForest::predictProbabilities():"
        " Probability matrix must have as many columns as there are classes.");

    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type & stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
    #undef RF_CHOOSER
    stop.set_external_parameters(ext_param_, tree_count());
    prob.init(NumericTraits<T>::zero());
    /* This code was originally there for testing early stopping
     * - we wanted the order of the trees to be randomized
    if(tree_indices_.size() != 0)
    {
        std::random_shuffle(tree_indices_.begin(),
                            tree_indices_.end());
    }
    */
    // Classify each row.
    for(int row=0; row < rowCount(features); ++row)
    {
        MultiArrayView<2, U, StridedArrayTag> currentRow(rowVector(features, row));

        // When the features contain a NaN, the instance doesn't belong to
        // any class - indicate this by returning a zero probability array.
        if(detail::contains_nan(currentRow))
        {
            rowVector(prob, row).init(0.0);
            continue;
        }

        ArrayVector<double>::const_iterator weights;

        // totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        // Let each tree classify...
        for(int k=0; k<options_.tree_count_; ++k)
        {
            // Get the weights predicted by a single tree.
            weights = trees_[k /*tree_indices_[k]*/].predict(currentRow);

            // Update the vote count. weights[-1] holds the total sample
            // weight of the leaf node, used when weighted prediction is on.
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                           + (1-weighted));
                prob(row, l) += static_cast<T>(cur_w);
                // Add every weight to totalWeight.
                totalWeight += cur_w;
            }
            if(stop.after_prediction(weights,
                                     k,
                                     rowVector(prob, row),
                                     totalWeight))
            {
                break;
            }
        }

        // Normalise the votes in each row by the total vote count (totalWeight).
        for(int l=0; l< ext_param_.class_count_; ++l)
        {
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
        }
    }

}

template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2>
void RandomForest<LabelType, PreprocessorTag>
    ::predictRaw(MultiArrayView<2, U, C1>const & features,
                 MultiArrayView<2, T, C2> & prob) const
{
    // Features are n x p.
    // prob is n x NumOfLabel: the raw vote weight for each sample in each class.

    vigra_precondition(rowCount(features) == rowCount(prob),
        "RandomForest::predictRaw():"
        " Feature matrix and probability matrix size mismatch.");

    // The feature matrix must have at least as many columns as were used
    // during training; any extra columns are ignored.
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
        "RandomForest::predictRaw():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == static_cast<MultiArrayIndex>(ext_param_.class_count_),
        "RandomForest::predictRaw():"
        " Probability matrix must have as many columns as there are classes.");

    prob.init(NumericTraits<T>::zero());
    /* This code was originally there for testing early stopping
     * - we wanted the order of the trees to be randomized
    if(tree_indices_.size() != 0)
    {
        std::random_shuffle(tree_indices_.begin(),
                            tree_indices_.end());
    }
    */
    // Classify each row.
    for(int row=0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        // totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        // Let each tree classify...
        for(int k=0; k<options_.tree_count_; ++k)
        {
            // Get the weights predicted by a single tree.
            weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));

            // Update the vote count.
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                           + (1-weighted));
                prob(row, l) += static_cast<T>(cur_w);
                // Add every weight to totalWeight.
                totalWeight += cur_w;
            }
        }
    }
    prob /= options_.tree_count_;

}

} // namespace vigra

#include "random_forest/rf_algorithm.hxx"
#endif // VIGRA_RANDOM_FOREST_HXX
