1 /*
2  * Software License Agreement (BSD License)
3  *
4  *  Point Cloud Library (PCL) - www.pointclouds.org
5  *  Copyright (c) 2010-2011, Willow Garage, Inc.
6  *
7  *  All rights reserved.
8  *
9  *  Redistribution and use in source and binary forms, with or without
10  *  modification, are permitted provided that the following conditions
11  *  are met:
12  *
13  *   * Redistributions of source code must retain the above copyright
14  *     notice, this list of conditions and the following disclaimer.
15  *   * Redistributions in binary form must reproduce the above
16  *     copyright notice, this list of conditions and the following
17  *     disclaimer in the documentation and/or other materials provided
18  *     with the distribution.
19  *   * Neither the name of Willow Garage, Inc. nor the names of its
20  *     contributors may be used to endorse or promote products derived
21  *     from this software without specific prior written permission.
22  *
23  *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
24  *  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
25  *  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
26  *  FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
27  *  COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
28  *  INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
29  *  BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
30  *  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31  *  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
32  *  LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
33  *  ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
34  *  POSSIBILITY OF SUCH DAMAGE.
35  *
36  */
37 
38 #pragma once
39 
40 #include <pcl/common/common.h>
41 #include <pcl/ml/dt/decision_forest.h>
42 #include <pcl/ml/dt/decision_tree.h>
43 #include <pcl/ml/dt/decision_tree_trainer.h>
44 #include <pcl/ml/feature_handler.h>
45 #include <pcl/ml/stats_estimator.h>
46 
47 #include <vector>
48 
49 namespace pcl {
50 
51 /** Trainer for decision trees. */
52 template <class FeatureType,
53           class DataSet,
54           class LabelType,
55           class ExampleIndex,
56           class NodeType>
57 class PCL_EXPORTS DecisionForestTrainer {
58 
59 public:
60   /** Constructor. */
61   DecisionForestTrainer();
62 
63   /** Destructor. */
64   virtual ~DecisionForestTrainer();
65 
66   /** Sets the number of trees to train.
67    *
68    * \param[in] num_of_trees the number of trees
69    */
70   inline void
setNumberOfTreesToTrain(const std::size_t num_of_trees)71   setNumberOfTreesToTrain(const std::size_t num_of_trees)
72   {
73     num_of_trees_to_train_ = num_of_trees;
74   }
75 
76   /** Sets the feature handler used to create and evaluate features.
77    *
78    * \param[in] feature_handler the feature handler
79    */
80   inline void
setFeatureHandler(pcl::FeatureHandler<FeatureType,DataSet,ExampleIndex> & feature_handler)81   setFeatureHandler(
82       pcl::FeatureHandler<FeatureType, DataSet, ExampleIndex>& feature_handler)
83   {
84     decision_tree_trainer_.setFeatureHandler(feature_handler);
85   }
86 
87   /** Sets the object for estimating the statistics for tree nodes.
88    *
89    * \param[in] stats_estimator the statistics estimator
90    */
91   inline void
setStatsEstimator(pcl::StatsEstimator<LabelType,NodeType,DataSet,ExampleIndex> & stats_estimator)92   setStatsEstimator(
93       pcl::StatsEstimator<LabelType, NodeType, DataSet, ExampleIndex>& stats_estimator)
94   {
95     decision_tree_trainer_.setStatsEstimator(stats_estimator);
96   }
97 
98   /** Sets the maximum depth of the learned tree.
99    *
100    * \param[in] max_tree_depth maximum depth of the learned tree
101    */
102   inline void
setMaxTreeDepth(const std::size_t max_tree_depth)103   setMaxTreeDepth(const std::size_t max_tree_depth)
104   {
105     decision_tree_trainer_.setMaxTreeDepth(max_tree_depth);
106   }
107 
108   /** Sets the number of features used to find optimal decision features.
109    *
110    * \param[in] num_of_features the number of features
111    */
112   inline void
setNumOfFeatures(const std::size_t num_of_features)113   setNumOfFeatures(const std::size_t num_of_features)
114   {
115     decision_tree_trainer_.setNumOfFeatures(num_of_features);
116   }
117 
118   /** Sets the number of thresholds tested for finding the optimal decision threshold on
119    *  the feature responses.
120    *
121    * \param[in] num_of_threshold the number of thresholds
122    */
123   inline void
setNumOfThresholds(const std::size_t num_of_threshold)124   setNumOfThresholds(const std::size_t num_of_threshold)
125   {
126     decision_tree_trainer_.setNumOfThresholds(num_of_threshold);
127   }
128 
129   /** Sets the input data set used for training.
130    *
131    * \param[in] data_set the data set used for training
132    */
133   inline void
setTrainingDataSet(DataSet & data_set)134   setTrainingDataSet(DataSet& data_set)
135   {
136     decision_tree_trainer_.setTrainingDataSet(data_set);
137   }
138 
139   /** Example indices that specify the data used for training.
140    *
141    * \param[in] examples the examples
142    */
143   inline void
setExamples(std::vector<ExampleIndex> & examples)144   setExamples(std::vector<ExampleIndex>& examples)
145   {
146     decision_tree_trainer_.setExamples(examples);
147   }
148 
149   /** Sets the label data corresponding to the example data.
150    *
151    * \param[in] label_data the label data
152    */
153   inline void
setLabelData(std::vector<LabelType> & label_data)154   setLabelData(std::vector<LabelType>& label_data)
155   {
156     decision_tree_trainer_.setLabelData(label_data);
157   }
158 
159   /** Sets the minimum number of examples to continue growing a tree.
160    *
161    * \param[in] n number of examples
162    */
163   inline void
setMinExamplesForSplit(std::size_t n)164   setMinExamplesForSplit(std::size_t n)
165   {
166     decision_tree_trainer_.setMinExamplesForSplit(n);
167   }
168 
169   /** Specify the thresholds to be used when evaluating features.
170    *
171    * \param[in] thres the threshold values
172    */
173   void
setThresholds(std::vector<float> & thres)174   setThresholds(std::vector<float>& thres)
175   {
176     decision_tree_trainer_.setThresholds(thres);
177   }
178 
179   /** Specify the data provider.
180    *
181    * \param[in] dtdp the data provider that should implement getDatasetAndLabels()
182    *            function
183    */
184   void
setDecisionTreeDataProvider(typename pcl::DecisionTreeTrainerDataProvider<FeatureType,DataSet,LabelType,ExampleIndex,NodeType>::Ptr & dtdp)185   setDecisionTreeDataProvider(
186       typename pcl::DecisionTreeTrainerDataProvider<FeatureType,
187                                                     DataSet,
188                                                     LabelType,
189                                                     ExampleIndex,
190                                                     NodeType>::Ptr& dtdp)
191   {
192     decision_tree_trainer_.setDecisionTreeDataProvider(dtdp);
193   }
194 
195   /** Specify if the features are randomly generated at each split node.
196    *
197    * \param[in] b do it or not
198    */
199   void
setRandomFeaturesAtSplitNode(bool b)200   setRandomFeaturesAtSplitNode(bool b)
201   {
202     decision_tree_trainer_.setRandomFeaturesAtSplitNode(b);
203   }
204 
205   /** Trains a decision forest using the set training data and settings.
206    *
207    * \param[out] forest destination for the trained forest
208    */
209   void
210   train(DecisionForest<NodeType>& forest);
211 
212 private:
213   /** The number of trees to train. */
214   std::size_t num_of_trees_to_train_;
215 
216   /** The trainer for the decision trees of the forest. */
217   pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType>
218       decision_tree_trainer_;
219 };
220 
221 } // namespace pcl
222 
223 #include <pcl/ml/impl/dt/decision_forest_trainer.hpp>
224