1 /************************************************************************/
2 /*                                                                      */
3 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
4 /*                                                                      */
5 /*    This file is part of the VIGRA computer vision library.           */
6 /*    The VIGRA Website is                                              */
7 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
8 /*    Please direct questions, bug reports, and contributions to        */
9 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
10 /*        vigra@informatik.uni-hamburg.de                               */
11 /*                                                                      */
12 /*    Permission is hereby granted, free of charge, to any person       */
13 /*    obtaining a copy of this software and associated documentation    */
14 /*    files (the "Software"), to deal in the Software without           */
15 /*    restriction, including without limitation the rights to use,      */
16 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
17 /*    sell copies of the Software, and to permit persons to whom the    */
18 /*    Software is furnished to do so, subject to the following          */
19 /*    conditions:                                                       */
20 /*                                                                      */
21 /*    The above copyright notice and this permission notice shall be    */
22 /*    included in all copies or substantial portions of the             */
23 /*    Software.                                                         */
24 /*                                                                      */
25 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
26 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
27 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
28 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
29 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
30 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
31 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
32 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
33 /*                                                                      */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
39 
40 #include <iostream>
41 #include <algorithm>
42 #include <map>
43 #include <set>
44 #include <list>
45 #include <numeric>
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
49 #include "matrix.hxx"
50 #include "metaprogramming.hxx"
51 #include "random.hxx"
52 #include "functorexpression.hxx"
53 #include "random_forest/rf_common.hxx"
54 #include "random_forest/rf_nodeproxy.hxx"
55 #include "random_forest/rf_split.hxx"
56 #include "random_forest/rf_decisionTree.hxx"
57 #include "random_forest/rf_visitors.hxx"
58 #include "random_forest/rf_region.hxx"
59 #include "sampling.hxx"
60 #include "random_forest/rf_preprocessing.hxx"
61 #include "random_forest/rf_online_prediction_set.hxx"
62 #include "random_forest/rf_earlystopping.hxx"
63 #include "random_forest/rf_ridge_split.hxx"
64 namespace vigra
65 {
66 
67 /** \addtogroup MachineLearning Machine Learning
68 
69     This module provides classification algorithms that map
70     features to labels or label probabilities.
71     Look at the \ref vigra::RandomForest class (for implementation version 2) or the
72     \ref vigra::rf3::random_forest() factory function (for implementation version 3)
73     for an overview of the functionality as well as use cases.
74 **/
75 
76 namespace detail
77 {
78 
79 
80 
81 /* \brief sampling option factory function
82  */
make_sampler_opt(RandomForestOptions & RF_opt)83 inline SamplerOptions make_sampler_opt ( RandomForestOptions     & RF_opt)
84 {
85     SamplerOptions return_opt;
86     return_opt.withReplacement(RF_opt.sample_with_replacement_);
87     return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
88     return return_opt;
89 }
90 }//namespace detail
91 
92 /** \brief Random forest version 2 (see also \ref vigra::rf3::RandomForest for version 3)
93  *
94  * \ingroup MachineLearning
95  *
96  * \tparam <LabelType = double> Type used for predicted labels.
97  * \tparam <PreprocessorTag = ClassificationTag> Class used to preprocess
98  *          the input while learning and predicting. Currently Available:
99  *          ClassificationTag and RegressionTag. It is recommended to use
100  *          Splitfunctor::Preprocessor_t while using custom splitfunctors
101  *          as they may need the data to be in a different format.
102  *          \sa Preprocessor
103  *
104  *  Simple usage for classification (regression is not yet supported):
105  *  look at RandomForest::learn() as well as RandomForestOptions() for additional
106  *  options.
107  *
108  *  \code
109  *  using namespace vigra;
110  *  using namespace rf;
 *  typedef xxx feature_t; // replace xxx with whichever type
 *  typedef yyy label_t;   // likewise
113  *
114  *  // allocate the training data
115  *  MultiArrayView<2, feature_t> f = get_training_features();
116  *  MultiArrayView<2, label_t>   l = get_training_labels();
117  *
118  *  RandomForest<label_t> rf;
119  *
120  *  // construct visitor to calculate out-of-bag error
121  *  visitors::OOB_Error oob_v;
122  *
123  *  // perform training
124  *  rf.learn(f, l, visitors::create_visitor(oob_v));
125  *
126  *  std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
127  *
128  *  // get features for new data to be used for prediction
129  *  MultiArrayView<2, feature_t> pf = get_features();
130  *
 *  // allocate space for the response (pf.shape(0) is the number of samples);
 *  // MultiArray owns its memory, whereas a MultiArrayView does not allocate
 *  MultiArray<2, label_t> prediction(Shape2(pf.shape(0), 1));
 *  MultiArray<2, double>  prob(Shape2(pf.shape(0), rf.class_count()));
134  *
135  *  // perform prediction on new data
136  *  rf.predictLabels(pf, prediction);
137  *  rf.predictProbabilities(pf, prob);
138  *
139  *  \endcode
140  *
141  *  Additional information such as Variable Importance measures are accessed
142  *  via Visitors defined in rf::visitors.
143  *  Have a look at rf::split for other splitting methods.
144  *
145 */
146 template <class LabelType = double , class PreprocessorTag = ClassificationTag >
147 class RandomForest
148 {
149 
150   public:
151     //public typedefs
152     typedef RandomForestOptions             Options_t;
153     typedef detail::DecisionTree            DecisionTree_t;
154     typedef ProblemSpec<LabelType>          ProblemSpec_t;
155     typedef GiniSplit                       Default_Split_t;
156     typedef EarlyStoppStd                   Default_Stop_t;
157     typedef rf::visitors::StopVisiting      Default_Visitor_t;
158     typedef  DT_StackEntry<ArrayVectorView<Int32>::iterator>
159                     StackEntry_t;
160     typedef LabelType                       LabelT;
161 
162     //problem independent data.
163     Options_t                                   options_;
164     //problem dependent data members - is only set if
165     //a copy constructor, some sort of import
166     //function or the learn function is called
167     ArrayVector<DecisionTree_t>                 trees_;
168     ProblemSpec_t                               ext_param_;
169     /*mutable ArrayVector<int>                    tree_indices_;*/
170     rf::visitors::OnlineLearnVisitor            online_visitor_;
171 
172 
reset()173     void reset()
174     {
175         ext_param_.clear();
176         trees_.clear();
177     }
178 
179   public:
180 
181     /** \name Constructors
182      * Note: No copy constructor specified as no pointers are manipulated
183      * in this class
184 
185      * @{
186      */
187 
188      /**\brief default constructor
189      *
190      * \param options   general options to the Random Forest. Must be of Type
191      *                  Options_t
192      * \param ext_param problem specific values that can be supplied
193      *                  additionally. (class weights , labels etc)
194      * \sa  RandomForestOptions, ProblemSpec
195      *
196      */
RandomForest(Options_t const & options=Options_t (),ProblemSpec_t const & ext_param=ProblemSpec_t ())197     RandomForest(Options_t const & options = Options_t(),
198                  ProblemSpec_t const & ext_param = ProblemSpec_t())
199     :
200         options_(options),
201         ext_param_(ext_param)/*,
202         tree_indices_(options.tree_count_,0)*/
203     {
204         /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
205             tree_indices_[ii] = ii;*/
206     }
207 
208     /**\brief Create RF from external source
209      * \param treeCount Number of trees to add.
210      * \param topology_begin
211      *                  Iterator to a Container where the topology_ data
212      *                  of the trees are stored.
213      *                  Iterator should support at least treeCount forward
214      *                  iterations. (i.e. topology_end - topology_begin >= treeCount
215      * \param parameter_begin
216      *                  iterator to a Container where the parameters_ data
217      *                  of the trees are stored. Iterator should support at
218      *                  least treeCount forward iterations.
219      * \param problem_spec
220      *                  Extrinsic parameters that specify the problem e.g.
221      *                  ClassCount, featureCount etc.
222      * \param options   (optional) specify options used to train the original
223      *                  Random forest. This parameter is not used anywhere
224      *                  during prediction and thus is optional.
225      *
226      */
227     template<class TopologyIterator, class ParameterIterator>
RandomForest(int treeCount,TopologyIterator topology_begin,ParameterIterator parameter_begin,ProblemSpec_t const & problem_spec,Options_t const & options=Options_t ())228     RandomForest(int                       treeCount,
229                   TopologyIterator         topology_begin,
230                   ParameterIterator        parameter_begin,
231                   ProblemSpec_t const & problem_spec,
232                   Options_t const &     options = Options_t())
233     :
234         trees_(treeCount, DecisionTree_t(problem_spec)),
235         ext_param_(problem_spec),
236         options_(options)
237     {
238          /* TODO: This constructor may be replaced by a Constructor using
239          * NodeProxy iterators to encapsulate the underlying data type.
240          */
241         for(int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
242         {
243             trees_[k].topology_ = *topology_begin;
244             trees_[k].parameters_ = *parameter_begin;
245         }
246     }
247 
248     /** @} */
249 
250 
251     /** \name Data Access
252      * data access interface - usage of member variables is deprecated
253      *
254      * @{
255      */
256 
257     /**\brief return external parameters for viewing
258      * \return ProblemSpec_t
259      */
ext_param() const260     ProblemSpec_t const & ext_param() const
261     {
262         vigra_precondition(ext_param_.used() == true,
263            "RandomForest::ext_param(): "
264            "Random forest has not been trained yet.");
265         return ext_param_;
266     }
267 
268     /**\brief set external parameters
269      *
270      *  \param in external parameters to be set
271      *
272      * set external parameters explicitly.
273      * If Random Forest has not been trained the preprocessor will
274      * either ignore filling values set this way or will throw an exception
275      * if values specified manually do not match the value calculated
276      & during the preparation step.
277      */
set_ext_param(ProblemSpec_t const & in)278     void set_ext_param(ProblemSpec_t const & in)
279     {
280         ignore_argument(in);
281         vigra_precondition(ext_param_.used() == false,
282             "RandomForest::set_ext_param():"
283             "Random forest has been trained! Call reset()"
284             "before specifying new extrinsic parameters.");
285     }
286 
287     /**\brief access random forest options
288      *
289      * \return random forest options
290      */
set_options()291     Options_t & set_options()
292     {
293         return options_;
294     }
295 
296 
297     /**\brief access const random forest options
298      *
299      * \return const Option_t
300      */
options() const301     Options_t const & options() const
302     {
303         return options_;
304     }
305 
306     /**\brief access const trees
307      */
tree(int index) const308     DecisionTree_t const & tree(int index) const
309     {
310         return trees_[index];
311     }
312 
313     /**\brief access trees
314      */
tree(int index)315     DecisionTree_t & tree(int index)
316     {
317         return trees_[index];
318     }
319 
320     /**\brief return number of features used while
321      * training.
322      */
feature_count() const323     int feature_count() const
324     {
325       return ext_param_.column_count_;
326     }
327 
328 
329     /**\brief return number of features used while
330      * training.
331      *
332      * deprecated. Use feature_count() instead.
333      */
column_count() const334     int column_count() const
335     {
336       return ext_param_.column_count_;
337     }
338 
339     /**\brief return number of classes used while
340      * training.
341      */
class_count() const342     int class_count() const
343     {
344       return ext_param_.class_count_;
345     }
346 
347     /**\brief return number of trees
348      */
tree_count() const349     int tree_count() const
350     {
351       return options_.tree_count_;
352     }
353 
354     /** @} */
355 
356     /**\name Learning
357      * Following functions differ in the degree of customization
358      * allowed
359      *
360      * @{
361      */
362 
363     /**\brief learn on data with custom config and random number generator
364      *
365      * \param features  a N x M matrix containing N samples with M
366      *                  features
367      * \param response  a N x D matrix containing the corresponding
368      *                  response. Current split functors assume D to
369      *                  be 1 and ignore any additional columns.
370      *                  This is not enforced to allow future support
371      *                  for uncertain labels, label independent strata etc.
372      *                  The Preprocessor specified during construction
373      *                  should be able to handle features and labels
374      *                  features and the labels.
375      *                  see also: SplitFunctor, Preprocessing
376      *
377      * \param visitor   visitor which is to be applied after each split,
378      *                  tree and at the end. Use rf_default() for using
379      *                  default value. (No Visitors)
380      *                  see also: rf::visitors
381      * \param split     split functor to be used to calculate each split
382      *                  use rf_default() for using default value. (GiniSplit)
383      *                  see also:  rf::split
384      * \param stop
385      *                  predicate to be used to calculate each split
386      *                  use rf_default() for using default value. (EarlyStoppStd)
387      * \param random    RandomNumberGenerator to be used. Use
388      *                  rf_default() to use default value.(RandomMT19337)
389      *
390      *
391      */
392     template <class U, class C1,
393              class U2,class C2,
394              class Split_t,
395              class Stop_t,
396              class Visitor_t,
397              class Random_t>
398     void learn( MultiArrayView<2, U, C1> const  &   features,
399                 MultiArrayView<2, U2,C2> const  &   response,
400                 Visitor_t                           visitor,
401                 Split_t                             split,
402                 Stop_t                              stop,
403                 Random_t                 const  &   random);
404 
405     template <class U, class C1,
406              class U2,class C2,
407              class Split_t,
408              class Stop_t,
409              class Visitor_t>
learn(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & response,Visitor_t visitor,Split_t split,Stop_t stop)410     void learn( MultiArrayView<2, U, C1> const  &   features,
411                 MultiArrayView<2, U2,C2> const  &   response,
412                 Visitor_t                           visitor,
413                 Split_t                             split,
414                 Stop_t                              stop)
415 
416     {
417         RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
418         learn(  features,
419                 response,
420                 visitor,
421                 split,
422                 stop,
423                 rnd);
424     }
425 
426     template <class U, class C1, class U2,class C2, class Visitor_t>
learn(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & labels,Visitor_t visitor)427     void learn( MultiArrayView<2, U, C1> const  & features,
428                 MultiArrayView<2, U2,C2> const  & labels,
429                 Visitor_t                         visitor)
430     {
431         learn(  features,
432                 labels,
433                 visitor,
434                 rf_default(),
435                 rf_default());
436     }
437 
438     template <class U, class C1, class U2,class C2,
439               class Visitor_t, class Split_t>
learn(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & labels,Visitor_t visitor,Split_t split)440     void learn(   MultiArrayView<2, U, C1> const  & features,
441                   MultiArrayView<2, U2,C2> const  & labels,
442                   Visitor_t                         visitor,
443                   Split_t                           split)
444     {
445         learn(  features,
446                 labels,
447                 visitor,
448                 split,
449                 rf_default());
450     }
451 
452     /**\brief learn on data with default configuration
453      *
454      * \param features  a N x M matrix containing N samples with M
455      *                  features
456      * \param labels    a N x D matrix containing the corresponding
457      *                  N labels. Current split functors assume D to
458      *                  be 1 and ignore any additional columns.
459      *                  this is not enforced to allow future support
460      *                  for uncertain labels.
461      *
462      * learning is done with:
463      *
464      * \sa rf::split, EarlyStoppStd
465      *
466      * - Randomly seeded random number generator
467      * - default gini split functor as described by Breiman
468      * - default The standard early stopping criterion
469      */
470     template <class U, class C1, class U2,class C2>
learn(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & labels)471     void learn(   MultiArrayView<2, U, C1> const  & features,
472                     MultiArrayView<2, U2,C2> const  & labels)
473     {
474         learn(  features,
475                 labels,
476                 rf_default(),
477                 rf_default(),
478                 rf_default());
479     }
480 
481 
482     template<class U,class C1,
483         class U2, class C2,
484         class Split_t,
485         class Stop_t,
486         class Visitor_t,
487         class Random_t>
488     void onlineLearn(   MultiArrayView<2,U,C1> const & features,
489                         MultiArrayView<2,U2,C2> const & response,
490                         int new_start_index,
491                         Visitor_t visitor_,
492                         Split_t split_,
493                         Stop_t stop_,
494                         Random_t & random,
495                         bool adjust_thresholds=false);
496 
497     template <class U, class C1, class U2,class C2>
onlineLearn(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & labels,int new_start_index,bool adjust_thresholds=false)498     void onlineLearn(   MultiArrayView<2, U, C1> const  & features,
499                         MultiArrayView<2, U2,C2> const  & labels,int new_start_index,bool adjust_thresholds=false)
500     {
501         RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
502         onlineLearn(features,
503                     labels,
504                     new_start_index,
505                     rf_default(),
506                     rf_default(),
507                     rf_default(),
508                     rnd,
509                     adjust_thresholds);
510     }
511 
512     template<class U,class C1,
513         class U2, class C2,
514         class Split_t,
515         class Stop_t,
516         class Visitor_t,
517         class Random_t>
518     void reLearnTree(MultiArrayView<2,U,C1> const & features,
519                      MultiArrayView<2,U2,C2> const & response,
520                      int treeId,
521                      Visitor_t visitor_,
522                      Split_t split_,
523                      Stop_t stop_,
524                      Random_t & random);
525 
526     template<class U, class C1, class U2, class C2>
reLearnTree(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & labels,int treeId)527     void reLearnTree(MultiArrayView<2, U, C1> const & features,
528                      MultiArrayView<2, U2, C2> const & labels,
529                      int treeId)
530     {
531         RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
532         reLearnTree(features,
533                     labels,
534                     treeId,
535                     rf_default(),
536                     rf_default(),
537                     rf_default(),
538                     rnd);
539     }
540 
541     /** @} */
542 
543 
544 
545     /**\name Prediction
546      *
547      * @{
548      */
549 
550     /** \brief predict a label given a feature.
551      *
552      * \param features: a 1 by featureCount matrix containing
553      *        data point to be predicted (this only works in
554      *        classification setting)
555      * \param stop: early stopping criterion
556      * \return double value representing class. You can use the
557      *         predictLabels() function together with the
558      *         rf.external_parameter().class_type_ attribute
559      *         to get back the same type used during learning.
560      */
561     template <class U, class C, class Stop>
562     LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;
563 
564     template <class U, class C>
predictLabel(MultiArrayView<2,U,C> const & features)565     LabelType predictLabel(MultiArrayView<2, U, C>const & features)
566     {
567         return predictLabel(features, rf_default());
568     }
569     /** \brief predict a label with features and class priors
570      *
571      * \param features: same as above.
572      * \param prior:   iterator to prior weighting of classes
573      * \return sam as above.
574      */
575     template <class U, class C>
576     LabelType predictLabel(MultiArrayView<2, U, C> const & features,
577                                 ArrayVectorView<double> prior) const;
578 
579     /** \brief predict multiple labels with given features
580      *
581      * \param features: a n by featureCount matrix containing
582      *        data point to be predicted (this only works in
583      *        classification setting)
584      * \param labels: a n by 1 matrix passed by reference to store
585      *        output.
586      *
587      * If the input contains an NaN value, an precondition exception is thrown.
588      */
589     template <class U, class C1, class T, class C2>
predictLabels(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,T,C2> & labels) const590     void predictLabels(MultiArrayView<2, U, C1>const & features,
591                        MultiArrayView<2, T, C2> & labels) const
592     {
593         vigra_precondition(features.shape(0) == labels.shape(0),
594             "RandomForest::predictLabels(): Label array has wrong size.");
595         for(int k=0; k<features.shape(0); ++k)
596         {
597             vigra_precondition(!detail::contains_nan(rowVector(features, k)),
598                 "RandomForest::predictLabels(): NaN in feature matrix.");
599             labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
600         }
601     }
602 
603     /** \brief predict multiple labels with given features
604      *
605      * \param features: a n by featureCount matrix containing
606      *        data point to be predicted (this only works in
607      *        classification setting)
608      * \param labels: a n by 1 matrix passed by reference to store
609      *        output.
610      * \param nanLabel: label to be returned for the row of the input that
611      *        contain an NaN value.
612      */
613     template <class U, class C1, class T, class C2>
predictLabels(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,T,C2> & labels,LabelType nanLabel) const614     void predictLabels(MultiArrayView<2, U, C1>const & features,
615                        MultiArrayView<2, T, C2> & labels,
616                        LabelType nanLabel) const
617     {
618         vigra_precondition(features.shape(0) == labels.shape(0),
619             "RandomForest::predictLabels(): Label array has wrong size.");
620         for(int k=0; k<features.shape(0); ++k)
621         {
622             if(detail::contains_nan(rowVector(features, k)))
623                 labels(k,0) = nanLabel;
624             else
625                 labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
626         }
627     }
628 
629     /** \brief predict multiple labels with given features
630      *
631      * \param features: a n by featureCount matrix containing
632      *        data point to be predicted (this only works in
633      *        classification setting)
634      * \param labels: a n by 1 matrix passed by reference to store
635      *        output.
636      * \param stop: an early stopping criterion.
637      */
638     template <class U, class C1, class T, class C2, class Stop>
predictLabels(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,T,C2> & labels,Stop & stop) const639     void predictLabels(MultiArrayView<2, U, C1>const & features,
640                        MultiArrayView<2, T, C2> & labels,
641                        Stop                     & stop) const
642     {
643         vigra_precondition(features.shape(0) == labels.shape(0),
644             "RandomForest::predictLabels(): Label array has wrong size.");
645         for(int k=0; k<features.shape(0); ++k)
646             labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
647     }
648     /** \brief predict the class probabilities for multiple labels
649      *
650      *  \param features same as above
651      *  \param prob a n x class_count_ matrix. passed by reference to
652      *  save class probabilities
653      *  \param stop earlystopping criterion
654      *  \sa EarlyStopping
655 
656         When a row of the feature array contains an NaN, the corresponding instance
657         cannot belong to any of the classes. The corresponding row in the probability
658         array will therefore contain all zeros.
659      */
660     template <class U, class C1, class T, class C2, class Stop>
661     void predictProbabilities(MultiArrayView<2, U, C1>const &   features,
662                               MultiArrayView<2, T, C2> &        prob,
663                               Stop                     &        stop) const;
664     template <class T1,class T2, class C>
665     void predictProbabilities(OnlinePredictionSet<T1> &  predictionSet,
666                                MultiArrayView<2, T2, C> &       prob);
667 
668     /** \brief predict the class probabilities for multiple labels
669      *
670      *  \param features same as above
671      *  \param prob a n x class_count_ matrix. passed by reference to
672      *  save class probabilities
673      */
674     template <class U, class C1, class T, class C2>
predictProbabilities(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,T,C2> & prob) const675     void predictProbabilities(MultiArrayView<2, U, C1>const &   features,
676                               MultiArrayView<2, T, C2> &        prob)  const
677     {
678         predictProbabilities(features, prob, rf_default());
679     }
680 
681     template <class U, class C1, class T, class C2>
682     void predictRaw(MultiArrayView<2, U, C1>const &   features,
683                     MultiArrayView<2, T, C2> &        prob)  const;
684 
685 
686     /** @} */
687 
688 };
689 
690 
691 template <class LabelType, class PreprocessorTag>
692 template<class U,class C1,
693     class U2, class C2,
694     class Split_t,
695     class Stop_t,
696     class Visitor_t,
697     class Random_t>
onlineLearn(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & response,int new_start_index,Visitor_t visitor_,Split_t split_,Stop_t stop_,Random_t & random,bool adjust_thresholds)698 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
699                                                              MultiArrayView<2,U2,C2> const & response,
700                                                              int new_start_index,
701                                                              Visitor_t visitor_,
702                                                              Split_t split_,
703                                                              Stop_t stop_,
704                                                              Random_t & random,
705                                                              bool adjust_thresholds)
706 {
707     online_visitor_.activate();
708     online_visitor_.adjust_thresholds=adjust_thresholds;
709 
710     using namespace rf;
711     //typedefs
712     typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
713     typedef          UniformIntRandomFunctor<Random_t>
714                                                     RandFunctor_t;
715     // default values and initialization
716     // Value Chooser chooses second argument as value if first argument
717     // is of type RF_DEFAULT. (thanks to template magic - don't care about
718     // it - just smile and wave.
719 
720     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
721     Default_Stop_t default_stop(options_);
722     typename RF_CHOOSER(Stop_t)::type stop
723             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
724     Default_Split_t default_split;
725     typename RF_CHOOSER(Split_t)::type split
726             = RF_CHOOSER(Split_t)::choose(split_, default_split);
727     rf::visitors::StopVisiting stopvisiting;
728     typedef  rf::visitors::detail::VisitorNode
729                 <rf::visitors::OnlineLearnVisitor,
730                  typename RF_CHOOSER(Visitor_t)::type>
731                                                         IntermedVis;
732     IntermedVis
733         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
734     #undef RF_CHOOSER
735     vigra_precondition(options_.prepare_online_learning_,"onlineLearn: online learning must be enabled on RandomForest construction");
736 
737     // Preprocess the data to get something the split functor can work
738     // with. Also fill the ext_param structure by preprocessing
739     // option parameters that could only be completely evaluated
740     // when the training data is known.
741     ext_param_.class_count_=0;
742     Preprocessor_t preprocessor(    features, response,
743                                     options_, ext_param_);
744 
745     // Make stl compatible random functor.
746     RandFunctor_t           randint     ( random);
747 
748     // Give the Split functor information about the data.
749     split.set_external_parameters(ext_param_);
750     stop.set_external_parameters(ext_param_);
751 
752 
753     //Create poisson samples
754     PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));
755 
756     //TODO: visitors for online learning
757     //visitor.visit_at_beginning(*this, preprocessor);
758 
759     // THE MAIN EFFING RF LOOP - YEAY DUDE!
760     for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
761     {
762         online_visitor_.tree_id=ii;
763         poisson_sampler.sample();
764         std::map<int,int> leaf_parents;
765         leaf_parents.clear();
766         //Get all the leaf nodes for that sample
767         for(int s=0;s<poisson_sampler.numOfSamples();++s)
768         {
769             int sample=poisson_sampler[s];
770             online_visitor_.current_label=preprocessor.response()(sample,0);
771             online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
772             int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);
773 
774 
775             //Add to the list for that leaf
776             online_visitor_.add_to_index_list(ii,leaf,sample);
777             //TODO: Class count?
778             //Store parent
779             if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
780             {
781                 leaf_parents[leaf]=online_visitor_.last_node_id;
782             }
783         }
784 
785 
786         std::map<int,int>::iterator leaf_iterator;
787         for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
788         {
789             int leaf=leaf_iterator->first;
790             int parent=leaf_iterator->second;
791             int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
792             ArrayVector<Int32> indeces;
793             indeces.clear();
794             indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
795             StackEntry_t stack_entry(indeces.begin(),
796                                      indeces.end(),
797                                      ext_param_.class_count_);
798 
799 
800             if(parent!=-1)
801             {
802                 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
803                 {
804                     stack_entry.leftParent=parent;
805                 }
806                 else
807                 {
808                     vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
809                     stack_entry.rightParent=parent;
810                 }
811             }
812             //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
813             trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
814             //Now, the last one moved onto leaf
815             online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
816             //Now it should be classified correctly!
817         }
818 
819         /*visitor
820             .visit_after_tree(  *this,
821                                 preprocessor,
822                                 poisson_sampler,
823                                 stack_entry,
824                                 ii);*/
825     }
826 
827     //visitor.visit_at_end(*this, preprocessor);
828     online_visitor_.deactivate();
829 }
830 
831 template<class LabelType, class PreprocessorTag>
832 template<class U,class C1,
833     class U2, class C2,
834     class Split_t,
835     class Stop_t,
836     class Visitor_t,
837     class Random_t>
reLearnTree(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & response,int treeId,Visitor_t visitor_,Split_t split_,Stop_t stop_,Random_t & random)838 void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
839                  MultiArrayView<2,U2,C2> const & response,
840                  int treeId,
841                  Visitor_t visitor_,
842                  Split_t split_,
843                  Stop_t stop_,
844                  Random_t & random)
845 {
846     using namespace rf;
847 
848 
849     typedef          UniformIntRandomFunctor<Random_t>
850                                                     RandFunctor_t;
851 
852     // See rf_preprocessing.hxx for more info on this
853     ext_param_.class_count_=0;
854     typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;
855 
856     // default values and initialization
857     // Value Chooser chooses second argument as value if first argument
858     // is of type RF_DEFAULT. (thanks to template magic - don't care about
859     // it - just smile and wave.
860 
861     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
862     Default_Stop_t default_stop(options_);
863     typename RF_CHOOSER(Stop_t)::type stop
864             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
865     Default_Split_t default_split;
866     typename RF_CHOOSER(Split_t)::type split
867             = RF_CHOOSER(Split_t)::choose(split_, default_split);
868     rf::visitors::StopVisiting stopvisiting;
869     typedef  rf::visitors::detail::VisitorNode
870                 <rf::visitors::OnlineLearnVisitor,
871                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
872     IntermedVis
873         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
874     #undef RF_CHOOSER
875     vigra_precondition(options_.prepare_online_learning_,"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
876     online_visitor_.activate();
877 
878     // Make stl compatible random functor.
879     RandFunctor_t           randint     ( random);
880 
881     // Preprocess the data to get something the split functor can work
882     // with. Also fill the ext_param structure by preprocessing
883     // option parameters that could only be completely evaluated
884     // when the training data is known.
885     Preprocessor_t preprocessor(    features, response,
886                                     options_, ext_param_);
887 
888     // Give the Split functor information about the data.
889     split.set_external_parameters(ext_param_);
890     stop.set_external_parameters(ext_param_);
891 
892     /**\todo    replace this crappy class out. It uses function pointers.
893      *          and is making code slower according to me.
894      *          Comment from Nathan: This is copied from Rahul, so me=Rahul
895      */
896     Sampler<Random_t > sampler(preprocessor.strata().begin(),
897                                preprocessor.strata().end(),
898                                detail::make_sampler_opt(options_)
899                                         .sampleSize(ext_param().actual_msample_),
900                                &random);
901     //initialize First region/node/stack entry
902     sampler
903         .sample();
904 
905     StackEntry_t
906         first_stack_entry(  sampler.sampledIndices().begin(),
907                             sampler.sampledIndices().end(),
908                             ext_param_.class_count_);
909     first_stack_entry
910         .set_oob_range(     sampler.oobIndices().begin(),
911                             sampler.oobIndices().end());
912     online_visitor_.reset_tree(treeId);
913     online_visitor_.tree_id=treeId;
914     trees_[treeId].reset();
915     trees_[treeId]
916         .learn( preprocessor.features(),
917                 preprocessor.response(),
918                 first_stack_entry,
919                 split,
920                 stop,
921                 visitor,
922                 randint);
923     visitor
924         .visit_after_tree(  *this,
925                             preprocessor,
926                             sampler,
927                             first_stack_entry,
928                             treeId);
929 
930     online_visitor_.deactivate();
931 }
932 
933 template <class LabelType, class PreprocessorTag>
934 template <class U, class C1,
935          class U2,class C2,
936          class Split_t,
937          class Stop_t,
938          class Visitor_t,
939          class Random_t>
940 void RandomForest<LabelType, PreprocessorTag>::
learn(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,U2,C2> const & response,Visitor_t visitor_,Split_t split_,Stop_t stop_,Random_t const & random)941                      learn( MultiArrayView<2, U, C1> const  &   features,
942                             MultiArrayView<2, U2,C2> const  &   response,
943                             Visitor_t                           visitor_,
944                             Split_t                             split_,
945                             Stop_t                              stop_,
946                             Random_t                 const  &   random)
947 {
948     using namespace rf;
949     //this->reset();
950     //typedefs
951     typedef          UniformIntRandomFunctor<Random_t>
952                                                     RandFunctor_t;
953 
954     // See rf_preprocessing.hxx for more info on this
955     typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;
956 
957     vigra_precondition(features.shape(0) == response.shape(0),
958         "RandomForest::learn(): shape mismatch between features and response.");
959 
960     // default values and initialization
961     // Value Chooser chooses second argument as value if first argument
962     // is of type RF_DEFAULT. (thanks to template magic - don't care about
963     // it - just smile and wave).
964 
965     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
966     Default_Stop_t default_stop(options_);
967     typename RF_CHOOSER(Stop_t)::type stop
968             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
969     Default_Split_t default_split;
970     typename RF_CHOOSER(Split_t)::type split
971             = RF_CHOOSER(Split_t)::choose(split_, default_split);
972     rf::visitors::StopVisiting stopvisiting;
973     typedef  rf::visitors::detail::VisitorNode<
974                 rf::visitors::OnlineLearnVisitor,
975                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
976     IntermedVis
977         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
978     #undef RF_CHOOSER
979     if(options_.prepare_online_learning_)
980         online_visitor_.activate();
981     else
982         online_visitor_.deactivate();
983 
984 
985     // Make stl compatible random functor.
986     RandFunctor_t           randint     ( random);
987 
988 
989     // Preprocess the data to get something the split functor can work
990     // with. Also fill the ext_param structure by preprocessing
991     // option parameters that could only be completely evaluated
992     // when the training data is known.
993     Preprocessor_t preprocessor(    features, response,
994                                     options_, ext_param_);
995 
996     // Give the Split functor information about the data.
997     split.set_external_parameters(ext_param_);
998     stop.set_external_parameters(ext_param_);
999 
1000 
1001     //initialize trees.
1002     trees_.resize(options_.tree_count_  , DecisionTree_t(ext_param_));
1003 
1004     Sampler<Random_t > sampler(preprocessor.strata().begin(),
1005                                preprocessor.strata().end(),
1006                                detail::make_sampler_opt(options_)
1007                                         .sampleSize(ext_param().actual_msample_),
1008                                &random);
1009 
1010     visitor.visit_at_beginning(*this, preprocessor);
1011     // THE MAIN EFFING RF LOOP - YEAY DUDE!
1012 
1013     for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
1014     {
1015         //initialize First region/node/stack entry
1016         sampler
1017             .sample();
1018         StackEntry_t
1019             first_stack_entry(  sampler.sampledIndices().begin(),
1020                                 sampler.sampledIndices().end(),
1021                                 ext_param_.class_count_);
1022         first_stack_entry
1023             .set_oob_range(     sampler.oobIndices().begin(),
1024                                 sampler.oobIndices().end());
1025         trees_[ii]
1026             .learn(             preprocessor.features(),
1027                                 preprocessor.response(),
1028                                 first_stack_entry,
1029                                 split,
1030                                 stop,
1031                                 visitor,
1032                                 randint);
1033         visitor
1034             .visit_after_tree(  *this,
1035                                 preprocessor,
1036                                 sampler,
1037                                 first_stack_entry,
1038                                 ii);
1039     }
1040 
1041     visitor.visit_at_end(*this, preprocessor);
1042     // Only for online learning?
1043     online_visitor_.deactivate();
1044 }
1045 
1046 
1047 
1048 
1049 template <class LabelType, class Tag>
1050 template <class U, class C, class Stop>
1051 LabelType RandomForest<LabelType, Tag>
predictLabel(MultiArrayView<2,U,C> const & features,Stop & stop) const1052     ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
1053 {
1054     vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1055         "RandomForestn::predictLabel():"
1056             " Too few columns in feature matrix.");
1057     vigra_precondition(rowCount(features) == 1,
1058         "RandomForestn::predictLabel():"
1059             " Feature matrix must have a singlerow.");
1060     MultiArray<2, double> probabilities(Shape2(1, ext_param_.class_count_), 0.0);
1061     LabelType          d;
1062     predictProbabilities(features, probabilities, stop);
1063     ext_param_.to_classlabel(argMax(probabilities), d);
1064     return d;
1065 }
1066 
1067 
1068 //Same thing as above with priors for each label !!!
1069 template <class LabelType, class PreprocessorTag>
1070 template <class U, class C>
1071 LabelType RandomForest<LabelType, PreprocessorTag>
predictLabel(MultiArrayView<2,U,C> const & features,ArrayVectorView<double> priors) const1072     ::predictLabel( MultiArrayView<2, U, C> const & features,
1073                     ArrayVectorView<double> priors) const
1074 {
1075     using namespace functor;
1076     vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1077         "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1078     vigra_precondition(rowCount(features) == 1,
1079         "RandomForestn::predictLabel():"
1080         " Feature matrix must have a single row.");
1081     Matrix<double>  prob(1,ext_param_.class_count_);
1082     predictProbabilities(features, prob);
1083     std::transform( prob.begin(), prob.end(),
1084                     priors.begin(), prob.begin(),
1085                     Arg1()*Arg2());
1086     LabelType          d;
1087     ext_param_.to_classlabel(argMax(prob), d);
1088     return d;
1089 }
1090 
1091 template<class LabelType,class PreprocessorTag>
1092 template <class T1,class T2, class C>
1093 void RandomForest<LabelType,PreprocessorTag>
predictProbabilities(OnlinePredictionSet<T1> & predictionSet,MultiArrayView<2,T2,C> & prob)1094     ::predictProbabilities(OnlinePredictionSet<T1> &  predictionSet,
1095                           MultiArrayView<2, T2, C> &       prob)
1096 {
1097     //Features are n xp
1098     //prob is n x NumOfLabel probability for each feature in each class
1099 
1100     vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
1101                        "RandomFroest::predictProbabilities():"
1102                        " Feature matrix and probability matrix size mismatch.");
1103     // num of features must be bigger than num of features in Random forest training
1104     // but why bigger?
1105     vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
1106       "RandomForestn::predictProbabilities():"
1107         " Too few columns in feature matrix.");
1108     vigra_precondition( columnCount(prob)
1109                         == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1110       "RandomForestn::predictProbabilities():"
1111       " Probability matrix must have as many columns as there are classes.");
1112     prob.init(0.0);
1113     //store total weights
1114     std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1115     //Go through all trees
1116     int set_id=-1;
1117     for(int k=0; k<options_.tree_count_; ++k)
1118     {
1119         set_id=(set_id+1) % predictionSet.indices[0].size();
1120         typedef std::set<SampleRange<T1> > my_set;
1121         typedef typename my_set::iterator set_it;
1122         //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
1123         //Build a stack with all the ranges we have
1124         std::vector<std::pair<int,set_it> > stack;
1125         stack.clear();
1126         for(set_it i=predictionSet.ranges[set_id].begin();
1127              i!=predictionSet.ranges[set_id].end();++i)
1128             stack.push_back(std::pair<int,set_it>(2,i));
1129         //get weights predicted by single tree
1130         int num_decisions=0;
1131         while(!stack.empty())
1132         {
1133             set_it range=stack.back().second;
1134             int index=stack.back().first;
1135             stack.pop_back();
1136             ++num_decisions;
1137 
1138             if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1139             {
1140                 ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
1141                                                                             trees_[k].parameters_,
1142                                                                             index).prob_begin();
1143                 for(int i=range->start;i!=range->end;++i)
1144                 {
1145                     //update votecount.
1146                     for(int l=0; l<ext_param_.class_count_; ++l)
1147                     {
1148                         prob(predictionSet.indices[set_id][i], l) += static_cast<T2>(weights[l]);
1149                         //every weight in totalWeight.
1150                         totalWeights[predictionSet.indices[set_id][i]] += static_cast<T1>(weights[l]);
1151                     }
1152                 }
1153             }
1154 
1155             else
1156             {
1157                 if(trees_[k].topology_[index]!=i_ThresholdNode)
1158                 {
1159                     throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
1160                 }
1161                 Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
1162                 if(range->min_boundaries[node.column()]>=node.threshold())
1163                 {
1164                     //Everything goes to right child
1165                     stack.push_back(std::pair<int,set_it>(node.child(1),range));
1166                     continue;
1167                 }
1168                 if(range->max_boundaries[node.column()]<node.threshold())
1169                 {
1170                     //Everything goes to the left child
1171                     stack.push_back(std::pair<int,set_it>(node.child(0),range));
1172                     continue;
1173                 }
1174                 //We have to split at this node
1175                 SampleRange<T1> new_range=*range;
1176                 new_range.min_boundaries[node.column()]=FLT_MAX;
1177                 range->max_boundaries[node.column()]=-FLT_MAX;
1178                 new_range.start=new_range.end=range->end;
1179                 int i=range->start;
1180                 while(i!=range->end)
1181                 {
1182                     //Decide for range->indices[i]
1183                     if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1184                     {
1185                         new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1186                                                                     predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1187                         --range->end;
1188                         --new_range.start;
1189                         std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1190 
1191                     }
1192                     else
1193                     {
1194                         range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1195                                                                  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1196                         ++i;
1197                     }
1198                 }
1199                 //The old one ...
1200                 if(range->start==range->end)
1201                 {
1202                     predictionSet.ranges[set_id].erase(range);
1203                 }
1204                 else
1205                 {
1206                     stack.push_back(std::pair<int,set_it>(node.child(0),range));
1207                 }
1208                 //And the new one ...
1209                 if(new_range.start!=new_range.end)
1210                 {
1211                     std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1212                     stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1213                 }
1214             }
1215         }
1216         predictionSet.cumulativePredTime[k]=num_decisions;
1217     }
1218     for(unsigned int i=0;i<totalWeights.size();++i)
1219     {
1220         double test=0.0;
1221         //Normalise votes in each row by total VoteCount (totalWeight
1222         for(int l=0; l<ext_param_.class_count_; ++l)
1223         {
1224             test+=prob(i,l);
1225             prob(i, l) /= totalWeights[i];
1226         }
1227         assert(test==totalWeights[i]);
1228         assert(totalWeights[i]>0.0);
1229     }
1230 }
1231 
1232 template <class LabelType, class PreprocessorTag>
1233 template <class U, class C1, class T, class C2, class Stop_t>
1234 void RandomForest<LabelType, PreprocessorTag>
predictProbabilities(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,T,C2> & prob,Stop_t & stop_) const1235     ::predictProbabilities(MultiArrayView<2, U, C1>const &  features,
1236                            MultiArrayView<2, T, C2> &       prob,
1237                            Stop_t                   &       stop_) const
1238 {
1239     //Features are n xp
1240     //prob is n x NumOfLabel probability for each feature in each class
1241 
1242     vigra_precondition(rowCount(features) == rowCount(prob),
1243       "RandomForestn::predictProbabilities():"
1244         " Feature matrix and probability matrix size mismatch.");
1245 
1246     // num of features must be bigger than num of features in Random forest training
1247     // but why bigger?
1248     vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1249       "RandomForestn::predictProbabilities():"
1250         " Too few columns in feature matrix.");
1251     vigra_precondition( columnCount(prob)
1252                         == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1253       "RandomForestn::predictProbabilities():"
1254       " Probability matrix must have as many columns as there are classes.");
1255 
1256     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1257     Default_Stop_t default_stop(options_);
1258     typename RF_CHOOSER(Stop_t)::type & stop
1259             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1260     #undef RF_CHOOSER
1261     stop.set_external_parameters(ext_param_, tree_count());
1262     prob.init(NumericTraits<T>::zero());
1263     /* This code was originally there for testing early stopping
1264      * - we wanted the order of the trees to be randomized
1265     if(tree_indices_.size() != 0)
1266     {
1267        std::random_shuffle(tree_indices_.begin(),
1268                            tree_indices_.end());
1269     }
1270     */
1271     //Classify for each row.
1272     for(int row=0; row < rowCount(features); ++row)
1273     {
1274         MultiArrayView<2, U, StridedArrayTag> currentRow(rowVector(features, row));
1275 
1276         // when the features contain an NaN, the instance doesn't belong to any class
1277         // => indicate this by returning a zero probability array.
1278         if(detail::contains_nan(currentRow))
1279         {
1280             rowVector(prob, row).init(0.0);
1281             continue;
1282         }
1283 
1284         ArrayVector<double>::const_iterator weights;
1285 
1286         //totalWeight == totalVoteCount!
1287         double totalWeight = 0.0;
1288 
1289         //Let each tree classify...
1290         for(int k=0; k<options_.tree_count_; ++k)
1291         {
1292             //get weights predicted by single tree
1293             weights = trees_[k /*tree_indices_[k]*/].predict(currentRow);
1294 
1295             //update votecount.
1296             int weighted = options_.predict_weighted_;
1297             for(int l=0; l<ext_param_.class_count_; ++l)
1298             {
1299                 double cur_w = weights[l] * (weighted * (*(weights-1))
1300                                            + (1-weighted));
1301                 prob(row, l) += static_cast<T>(cur_w);
1302                 //every weight in totalWeight.
1303                 totalWeight += cur_w;
1304             }
1305             if(stop.after_prediction(weights,
1306                                      k,
1307                                      rowVector(prob, row),
1308                                      totalWeight))
1309             {
1310                 break;
1311             }
1312         }
1313 
1314         //Normalise votes in each row by total VoteCount (totalWeight
1315         for(int l=0; l< ext_param_.class_count_; ++l)
1316         {
1317             prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1318         }
1319     }
1320 
1321 }
1322 
1323 template <class LabelType, class PreprocessorTag>
1324 template <class U, class C1, class T, class C2>
1325 void RandomForest<LabelType, PreprocessorTag>
predictRaw(MultiArrayView<2,U,C1> const & features,MultiArrayView<2,T,C2> & prob) const1326     ::predictRaw(MultiArrayView<2, U, C1>const &  features,
1327                            MultiArrayView<2, T, C2> &       prob) const
1328 {
1329     //Features are n xp
1330     //prob is n x NumOfLabel probability for each feature in each class
1331 
1332     vigra_precondition(rowCount(features) == rowCount(prob),
1333       "RandomForestn::predictProbabilities():"
1334         " Feature matrix and probability matrix size mismatch.");
1335 
1336     // num of features must be bigger than num of features in Random forest training
1337     // but why bigger?
1338     vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1339       "RandomForestn::predictProbabilities():"
1340         " Too few columns in feature matrix.");
1341     vigra_precondition( columnCount(prob)
1342                         == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1343       "RandomForestn::predictProbabilities():"
1344       " Probability matrix must have as many columns as there are classes.");
1345 
1346     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1347     prob.init(NumericTraits<T>::zero());
1348     /* This code was originally there for testing early stopping
1349      * - we wanted the order of the trees to be randomized
1350     if(tree_indices_.size() != 0)
1351     {
1352        std::random_shuffle(tree_indices_.begin(),
1353                            tree_indices_.end());
1354     }
1355     */
1356     //Classify for each row.
1357     for(int row=0; row < rowCount(features); ++row)
1358     {
1359         ArrayVector<double>::const_iterator weights;
1360 
1361         //totalWeight == totalVoteCount!
1362         double totalWeight = 0.0;
1363 
1364         //Let each tree classify...
1365         for(int k=0; k<options_.tree_count_; ++k)
1366         {
1367             //get weights predicted by single tree
1368             weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1369 
1370             //update votecount.
1371             int weighted = options_.predict_weighted_;
1372             for(int l=0; l<ext_param_.class_count_; ++l)
1373             {
1374                 double cur_w = weights[l] * (weighted * (*(weights-1))
1375                                            + (1-weighted));
1376                 prob(row, l) += static_cast<T>(cur_w);
1377                 //every weight in totalWeight.
1378                 totalWeight += cur_w;
1379             }
1380         }
1381     }
1382     prob/= options_.tree_count_;
1383 
1384 }
1385 
1386 } // namespace vigra
1387 
1388 #include "random_forest/rf_algorithm.hxx"
1389 #endif // VIGRA_RANDOM_FOREST_HXX
1390