/*
 * Software License Agreement (BSD License)
 *
 *  Point Cloud Library (PCL) - www.pointclouds.org
 *  Copyright (c) 2010-2011, Willow Garage, Inc.
 *
 *  All rights reserved.
 *
 *  Redistribution and use in source and binary forms, with or without
 *  modification, are permitted provided that the following conditions
 *  are met:
 *
 *   * Redistributions of source code must retain the above copyright
 *     notice, this list of conditions and the following disclaimer.
 *   * Redistributions in binary form must reproduce the above
 *     copyright notice, this list of conditions and the following
 *     disclaimer in the documentation and/or other materials provided
 *     with the distribution.
 *   * Neither the name of Willow Garage, Inc. nor the names of its
 *     contributors may be used to endorse or promote products derived
 *     from this software without specific prior written permission.
 *
 *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 *  "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 *  LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 *  FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 *  COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 *  INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 *  BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 *  LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 *  CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 *  LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 *  ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *  POSSIBILITY OF SUCH DAMAGE.
35 * 36 */ 37 38 #pragma once 39 40 #include <pcl/common/common.h> 41 #include <pcl/ml/dt/decision_forest.h> 42 #include <pcl/ml/dt/decision_tree.h> 43 #include <pcl/ml/dt/decision_tree_trainer.h> 44 #include <pcl/ml/feature_handler.h> 45 #include <pcl/ml/stats_estimator.h> 46 47 #include <vector> 48 49 namespace pcl { 50 51 /** Trainer for decision trees. */ 52 template <class FeatureType, 53 class DataSet, 54 class LabelType, 55 class ExampleIndex, 56 class NodeType> 57 class PCL_EXPORTS DecisionForestTrainer { 58 59 public: 60 /** Constructor. */ 61 DecisionForestTrainer(); 62 63 /** Destructor. */ 64 virtual ~DecisionForestTrainer(); 65 66 /** Sets the number of trees to train. 67 * 68 * \param[in] num_of_trees the number of trees 69 */ 70 inline void setNumberOfTreesToTrain(const std::size_t num_of_trees)71 setNumberOfTreesToTrain(const std::size_t num_of_trees) 72 { 73 num_of_trees_to_train_ = num_of_trees; 74 } 75 76 /** Sets the feature handler used to create and evaluate features. 77 * 78 * \param[in] feature_handler the feature handler 79 */ 80 inline void setFeatureHandler(pcl::FeatureHandler<FeatureType,DataSet,ExampleIndex> & feature_handler)81 setFeatureHandler( 82 pcl::FeatureHandler<FeatureType, DataSet, ExampleIndex>& feature_handler) 83 { 84 decision_tree_trainer_.setFeatureHandler(feature_handler); 85 } 86 87 /** Sets the object for estimating the statistics for tree nodes. 88 * 89 * \param[in] stats_estimator the statistics estimator 90 */ 91 inline void setStatsEstimator(pcl::StatsEstimator<LabelType,NodeType,DataSet,ExampleIndex> & stats_estimator)92 setStatsEstimator( 93 pcl::StatsEstimator<LabelType, NodeType, DataSet, ExampleIndex>& stats_estimator) 94 { 95 decision_tree_trainer_.setStatsEstimator(stats_estimator); 96 } 97 98 /** Sets the maximum depth of the learned tree. 
99 * 100 * \param[in] max_tree_depth maximum depth of the learned tree 101 */ 102 inline void setMaxTreeDepth(const std::size_t max_tree_depth)103 setMaxTreeDepth(const std::size_t max_tree_depth) 104 { 105 decision_tree_trainer_.setMaxTreeDepth(max_tree_depth); 106 } 107 108 /** Sets the number of features used to find optimal decision features. 109 * 110 * \param[in] num_of_features the number of features 111 */ 112 inline void setNumOfFeatures(const std::size_t num_of_features)113 setNumOfFeatures(const std::size_t num_of_features) 114 { 115 decision_tree_trainer_.setNumOfFeatures(num_of_features); 116 } 117 118 /** Sets the number of thresholds tested for finding the optimal decision threshold on 119 * the feature responses. 120 * 121 * \param[in] num_of_threshold the number of thresholds 122 */ 123 inline void setNumOfThresholds(const std::size_t num_of_threshold)124 setNumOfThresholds(const std::size_t num_of_threshold) 125 { 126 decision_tree_trainer_.setNumOfThresholds(num_of_threshold); 127 } 128 129 /** Sets the input data set used for training. 130 * 131 * \param[in] data_set the data set used for training 132 */ 133 inline void setTrainingDataSet(DataSet & data_set)134 setTrainingDataSet(DataSet& data_set) 135 { 136 decision_tree_trainer_.setTrainingDataSet(data_set); 137 } 138 139 /** Example indices that specify the data used for training. 140 * 141 * \param[in] examples the examples 142 */ 143 inline void setExamples(std::vector<ExampleIndex> & examples)144 setExamples(std::vector<ExampleIndex>& examples) 145 { 146 decision_tree_trainer_.setExamples(examples); 147 } 148 149 /** Sets the label data corresponding to the example data. 150 * 151 * \param[in] label_data the label data 152 */ 153 inline void setLabelData(std::vector<LabelType> & label_data)154 setLabelData(std::vector<LabelType>& label_data) 155 { 156 decision_tree_trainer_.setLabelData(label_data); 157 } 158 159 /** Sets the minimum number of examples to continue growing a tree. 
160 * 161 * \param[in] n number of examples 162 */ 163 inline void setMinExamplesForSplit(std::size_t n)164 setMinExamplesForSplit(std::size_t n) 165 { 166 decision_tree_trainer_.setMinExamplesForSplit(n); 167 } 168 169 /** Specify the thresholds to be used when evaluating features. 170 * 171 * \param[in] thres the threshold values 172 */ 173 void setThresholds(std::vector<float> & thres)174 setThresholds(std::vector<float>& thres) 175 { 176 decision_tree_trainer_.setThresholds(thres); 177 } 178 179 /** Specify the data provider. 180 * 181 * \param[in] dtdp the data provider that should implement getDatasetAndLabels() 182 * function 183 */ 184 void setDecisionTreeDataProvider(typename pcl::DecisionTreeTrainerDataProvider<FeatureType,DataSet,LabelType,ExampleIndex,NodeType>::Ptr & dtdp)185 setDecisionTreeDataProvider( 186 typename pcl::DecisionTreeTrainerDataProvider<FeatureType, 187 DataSet, 188 LabelType, 189 ExampleIndex, 190 NodeType>::Ptr& dtdp) 191 { 192 decision_tree_trainer_.setDecisionTreeDataProvider(dtdp); 193 } 194 195 /** Specify if the features are randomly generated at each split node. 196 * 197 * \param[in] b do it or not 198 */ 199 void setRandomFeaturesAtSplitNode(bool b)200 setRandomFeaturesAtSplitNode(bool b) 201 { 202 decision_tree_trainer_.setRandomFeaturesAtSplitNode(b); 203 } 204 205 /** Trains a decision forest using the set training data and settings. 206 * 207 * \param[out] forest destination for the trained forest 208 */ 209 void 210 train(DecisionForest<NodeType>& forest); 211 212 private: 213 /** The number of trees to train. */ 214 std::size_t num_of_trees_to_train_; 215 216 /** The trainer for the decision trees of the forest. */ 217 pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType> 218 decision_tree_trainer_; 219 }; 220 221 } // namespace pcl 222 223 #include <pcl/ml/impl/dt/decision_forest_trainer.hpp> 224