1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // License Agreement
11 // For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 // * Redistribution's of source code must retain the above copyright notice,
21 // this list of conditions and the following disclaimer.
22 //
23 // * Redistribution's in binary form must reproduce the above copyright notice,
24 // this list of conditions and the following disclaimer in the documentation
25 // and/or other materials provided with the distribution.
26 //
27 // * The name of the copyright holders may not be used to endorse or promote products
28 // derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #include "precomp.hpp"
44 namespace cv {
45 namespace ml {
46
47 //////////////////////////////////////////////////////////////////////////////////////////
48 // Random trees //
49 //////////////////////////////////////////////////////////////////////////////////////////
RTreeParams()50 RTreeParams::RTreeParams()
51 {
52 CV_TRACE_FUNCTION();
53 calcVarImportance = false;
54 nactiveVars = 0;
55 termCrit = TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 50, 0.1);
56 }
57
RTreeParams(bool _calcVarImportance,int _nactiveVars,TermCriteria _termCrit)58 RTreeParams::RTreeParams(bool _calcVarImportance,
59 int _nactiveVars,
60 TermCriteria _termCrit )
61 {
62 CV_TRACE_FUNCTION();
63 calcVarImportance = _calcVarImportance;
64 nactiveVars = _nactiveVars;
65 termCrit = _termCrit;
66 }
67
68
// Decision-tree engine specialized for random forests: builds an ensemble of
// bootstrapped trees, tracks the out-of-bag (OOB) error, and optionally
// computes permutation-based variable importance.
class DTreesImplForRTrees CV_FINAL : public DTreesImpl
{
public:
    DTreesImplForRTrees()
    {
        CV_TRACE_FUNCTION();
        // Forest-friendly defaults: shallow trees, no cross-validation
        // folds and no pruning (CVFolds=0, use1SERule/truncatePrunedTree off).
        params.setMaxDepth(5);
        params.setMinSampleCount(10);
        params.setRegressionAccuracy(0.f);
        params.useSurrogates = false;
        params.setMaxCategories(10);
        params.setCVFolds(0);
        params.use1SERule = false;
        params.truncatePrunedTree = false;
        params.priors = Mat();
        oobError = 0;
    }
    virtual ~DTreesImplForRTrees() {}

    // Resets the trained ensemble and the accumulated OOB error.
    void clear() CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        DTreesImpl::clear();
        oobError = 0.;
    }

    // Returns a freshly randomized subset of variables: allVars is shuffled
    // with nvars random transpositions, and its first m entries are copied
    // into activeVars (m was fixed in startTraining()).
    const vector<int>& getActiveVars() CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        RNG &rng = theRNG();
        int i, nvars = (int)allVars.size(), m = (int)activeVars.size();
        for( i = 0; i < nvars; i++ )
        {
            int i1 = rng.uniform(0, nvars);
            int i2 = rng.uniform(0, nvars);
            std::swap(allVars[i1], allVars[i2]);
        }
        for( i = 0; i < m; i++ )
            activeVars[i] = allVars[i];
        return activeVars;
    }

    // Prepares per-forest state on top of the base-class setup:
    // sizes the active-variable subset to nactiveVars, or round(sqrt(nvars))
    // when nactiveVars <= 0, clamped to [1, nvars].
    void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_Assert(!trainData.empty());
        DTreesImpl::startTraining(trainData, flags);
        int nvars = w->data->getNVars();
        int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
        m = std::min(std::max(m, 1), nvars);
        allVars.resize(nvars);
        activeVars.resize(m);
        for( i = 0; i < nvars; i++ )
            allVars[i] = varIdx[i];
    }

    // Releases the training-only variable buffers by swapping them with
    // empty vectors (guarantees the capacity is actually freed).
    void endTraining() CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        DTreesImpl::endTraining();
        vector<int> a, b;
        std::swap(allVars, a);
        std::swap(activeVars, b);
    }

    // Trains the forest: for each tree, draws a bootstrap sample (with
    // replacement), grows a tree on it, and — when needed — evaluates the
    // OOB error and updates permutation-based variable importance.
    // Stops after termCrit.maxCount trees (default cap 10000) or as soon as
    // the OOB error falls below termCrit.epsilon.
    bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        RNG &rng = theRNG();
        CV_Assert(!trainData.empty());
        startTraining(trainData, flags);
        int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
            rparams.termCrit.maxCount : 10000;
        int i, j, k, vi, vi_, n = (int)w->sidx.size();
        int nclasses = (int)classLabels.size();
        // eps > 0 enables the OOB-error-based early stop.
        double eps = (rparams.termCrit.type & TermCriteria::EPS) != 0 &&
            rparams.termCrit.epsilon > 0 ? rparams.termCrit.epsilon : 0.;
        vector<int> sidx(n);                    // bootstrap sample indices for the current tree
        vector<uchar> oobmask(n);               // 1 => sample i was NOT drawn (out-of-bag)
        vector<int> oobidx;
        vector<int> oobperm;
        vector<double> oobres(n, 0.);           // regression: accumulated per-sample predictions
        vector<int> oobcount(n, 0);             // regression: number of predictions per sample
        vector<int> oobvotes(n*nclasses, 0);    // classification: per-sample vote histogram
        int nvars = w->data->getNVars();
        int nallvars = w->data->getNAllVars();
        const int* vidx = !varIdx.empty() ? &varIdx[0] : 0;
        vector<float> samplebuf(nallvars);
        Mat samples = w->data->getSamples();
        float* psamples = samples.ptr<float>();
        size_t sstep0 = samples.step1(), sstep1 = 1;
        Mat sample0, sample(nallvars, 1, CV_32F, &samplebuf[0]);
        int predictFlags = _isClassifier ? (PREDICT_MAX_VOTE + RAW_OUTPUT) : PREDICT_SUM;

        bool calcOOBError = eps > 0 || rparams.calcVarImportance;
        double max_response = 0.;

        if( w->data->getLayout() == COL_SAMPLE )
            std::swap(sstep0, sstep1);

        if( !_isClassifier )
        {
            // max_response normalizes regression residuals before the
            // exp(-val^2) "correctness" score below; must be non-zero.
            for( i = 0; i < n; i++ )
            {
                double val = std::abs(w->ord_responses[w->sidx[i]]);
                max_response = std::max(max_response, val);
            }
            CV_Assert(fabs(max_response) > 0);
        }

        if( rparams.calcVarImportance )
            varImportance.resize(nallvars, 0.f);

        for( treeidx = 0; treeidx < ntrees; treeidx++ )
        {
            // Bootstrap: draw n samples with replacement; any sample never
            // drawn keeps oobmask==1 and is out-of-bag for this tree.
            for( i = 0; i < n; i++ )
                oobmask[i] = (uchar)1;

            for( i = 0; i < n; i++ )
            {
                j = rng.uniform(0, n);
                sidx[i] = w->sidx[j];
                oobmask[j] = (uchar)0;
            }
            int root = addTree( sidx );
            if( root < 0 )
                return false;

            if( calcOOBError )
            {
                oobidx.clear();
                for( i = 0; i < n; i++ )
                {
                    if( oobmask[i] )
                        oobidx.push_back(i);
                }
                int n_oob = (int)oobidx.size();
                // if there is no out-of-bag samples, we can not compute OOB error
                // nor update the variable importance vector; so we proceed to the next tree
                if( n_oob == 0 )
                    continue;
                double ncorrect_responses = 0.;

                oobError = 0.;
                for( i = 0; i < n_oob; i++ )
                {
                    j = oobidx[i];
                    // Wrap the sample's row (or column, after the stride swap
                    // above) of the training matrix without copying it.
                    sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );

                    double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
                    double sample_weight = w->sample_weights[w->sidx[j]];
                    if( !_isClassifier )
                    {
                        // Regression OOB error: weighted squared deviation of
                        // the running ensemble average from the true response.
                        oobres[j] += val;
                        oobcount[j]++;
                        double true_val = w->ord_responses[w->sidx[j]];
                        double a = oobres[j]/oobcount[j] - true_val;
                        oobError += sample_weight * a*a;
                        val = (val - true_val)/max_response;
                        ncorrect_responses += std::exp( -val*val );
                    }
                    else
                    {
                        int ival = cvRound(val);
                        //Voting scheme to combine OOB errors of each tree
                        int* votes = &oobvotes[j*nclasses];
                        votes[ival]++;
                        int best_class = 0;
                        for( k = 1; k < nclasses; k++ )
                            if( votes[best_class] < votes[k] )
                                best_class = k;
                        // diff = 1 when the majority vote (over all trees so
                        // far) disagrees with the ground-truth label.
                        int diff = best_class != w->cat_responses[w->sidx[j]];
                        oobError += sample_weight * diff;
                        ncorrect_responses += diff == 0;
                    }
                }

                oobError /= n_oob;
                if( rparams.calcVarImportance && n_oob > 1 )
                {
                    // Permutation importance: re-predict each OOB sample with
                    // one variable's value taken from another random OOB
                    // sample; the drop in correct responses is credited to
                    // that variable.
                    Mat sample_clone;
                    oobperm.resize(n_oob);
                    for( i = 0; i < n_oob; i++ )
                        oobperm[i] = oobidx[i];
                    for (i = n_oob - 1; i > 0; --i) //Randomly shuffle indices so we can permute features
                    {
                        int r_i = rng.uniform(0, n_oob);
                        std::swap(oobperm[i], oobperm[r_i]);
                    }

                    for( vi_ = 0; vi_ < nvars; vi_++ )
                    {
                        vi = vidx ? vidx[vi_] : vi_; //Ensure that only the user specified predictors are used for training
                        double ncorrect_responses_permuted = 0;

                        for( i = 0; i < n_oob; i++ )
                        {
                            j = oobidx[i];
                            int vj = oobperm[i];
                            sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
                            sample0.copyTo(sample_clone); //create a copy so we don't mess up the original data
                            sample_clone.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];

                            double val = predictTrees(Range(treeidx, treeidx+1), sample_clone, predictFlags);
                            if( !_isClassifier )
                            {
                                val = (val - w->ord_responses[w->sidx[j]])/max_response;
                                ncorrect_responses_permuted += exp( -val*val );
                            }
                            else
                            {
                                ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
                            }
                        }
                        varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
                    }
                }
            }
            // Early stop once the OOB error estimate is good enough.
            if( calcOOBError && oobError < eps )
                break;
        }

        if( rparams.calcVarImportance )
        {
            // Clip negative importances to zero and L1-normalize so the
            // importances sum to 1.
            for( vi_ = 0; vi_ < nallvars; vi_++ )
                varImportance[vi_] = std::max(varImportance[vi_], 0.f);
            normalize(varImportance, varImportance, 1., 0, NORM_L1);
        }
        endTraining();
        return true;
    }

    // Serializes the forest-specific training parameter on top of the
    // base-class parameters.
    void writeTrainingParams( FileStorage& fs ) const CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        DTreesImpl::writeTrainingParams(fs);
        fs << "nactive_vars" << rparams.nactiveVars;
    }

    // Writes the trained model: format/params, OOB error, optional variable
    // importance, and the sequence of trees.
    void write( FileStorage& fs ) const CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        if( roots.empty() )
            CV_Error( CV_StsBadArg, "RTrees have not been trained" );

        writeFormat(fs);
        writeParams(fs);

        fs << "oob_error" << oobError;
        if( !varImportance.empty() )
            fs << "var_importance" << varImportance;

        int k, ntrees = (int)roots.size();

        fs << "ntrees" << ntrees
           << "trees" << "[";

        for( k = 0; k < ntrees; k++ )
        {
            fs << "{";
            writeTree(fs, roots[k]);
            fs << "}";
        }

        fs << "]";
    }

    // Restores the forest-specific parameter from the "training_params" node.
    void readParams( const FileNode& fn ) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        DTreesImpl::readParams(fn);

        FileNode tparams_node = fn["training_params"];
        rparams.nactiveVars = (int)tparams_node["nactive_vars"];
    }

    // Loads a previously serialized forest; the stored tree count must
    // match the number of entries in the "trees" sequence.
    void read( const FileNode& fn ) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        clear();

        //int nclasses = (int)fn["nclasses"];
        //int nsamples = (int)fn["nsamples"];
        oobError = (double)fn["oob_error"];
        int ntrees = (int)fn["ntrees"];

        readVectorOrMat(fn["var_importance"], varImportance);

        readParams(fn);

        FileNode trees_node = fn["trees"];
        FileNodeIterator it = trees_node.begin();
        CV_Assert( ntrees == (int)trees_node.size() );

        for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
        {
            FileNode nfn = (*it)["nodes"];
            readTree(nfn);
        }
    }

    // Collects per-tree predictions for every input sample.
    // PREDICT_SUM mode: output is nsamples x ntrees CV_32F of raw per-tree
    // predictions. Otherwise: output is (nsamples+1) x nclasses CV_32S where
    // row 0 holds the class labels and each following row holds the vote
    // counts for the corresponding sample.
    void getVotes( InputArray input, OutputArray output, int flags ) const
    {
        CV_TRACE_FUNCTION();
        CV_Assert( !roots.empty() );
        int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
        Mat samples = input.getMat(), results;
        int i, j, nsamples = samples.rows;

        int predictType = flags & PREDICT_MASK;
        if( predictType == PREDICT_AUTO )
        {
            predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
                PREDICT_SUM : PREDICT_MAX_VOTE;
        }

        if( predictType == PREDICT_SUM )
        {
            output.create(nsamples, ntrees, CV_32F);
            results = output.getMat();
            for( i = 0; i < nsamples; i++ )
            {
                for( j = 0; j < ntrees; j++ )
                {
                    float val = predictTrees( Range(j, j+1), samples.row(i), flags);
                    results.at<float> (i, j) = val;
                }
            }
        } else
        {
            vector<int> votes;
            output.create(nsamples+1, nclasses, CV_32S);
            results = output.getMat();

            // First row: the class labels the vote columns refer to.
            for ( j = 0; j < nclasses; j++)
            {
                results.at<int> (0, j) = classLabels[j];
            }

            for( i = 0; i < nsamples; i++ )
            {
                votes.clear();
                for( j = 0; j < ntrees; j++ )
                {
                    int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
                    votes.push_back(val);
                }

                for ( j = 0; j < nclasses; j++)
                {
                    results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
                }
            }
        }
    }

    // OOB error of the last evaluated tree iteration (0 if never computed).
    double getOOBError() const {
        return oobError;
    }

    RTreeParams rparams;            // forest-specific training parameters
    double oobError;                // most recent out-of-bag error estimate
    vector<float> varImportance;    // normalized permutation importances (may be empty)
    vector<int> allVars, activeVars; // training-time variable index buffers
};
434
435
436 class RTreesImpl CV_FINAL : public RTrees
437 {
438 public:
getCalculateVarImportance() const439 inline bool getCalculateVarImportance() const CV_OVERRIDE { return impl.rparams.calcVarImportance; }
setCalculateVarImportance(bool val)440 inline void setCalculateVarImportance(bool val) CV_OVERRIDE { impl.rparams.calcVarImportance = val; }
getActiveVarCount() const441 inline int getActiveVarCount() const CV_OVERRIDE { return impl.rparams.nactiveVars; }
setActiveVarCount(int val)442 inline void setActiveVarCount(int val) CV_OVERRIDE { impl.rparams.nactiveVars = val; }
getTermCriteria() const443 inline TermCriteria getTermCriteria() const CV_OVERRIDE { return impl.rparams.termCrit; }
setTermCriteria(const TermCriteria & val)444 inline void setTermCriteria(const TermCriteria& val) CV_OVERRIDE { impl.rparams.termCrit = val; }
445
getMaxCategories() const446 inline int getMaxCategories() const CV_OVERRIDE { return impl.params.getMaxCategories(); }
setMaxCategories(int val)447 inline void setMaxCategories(int val) CV_OVERRIDE { impl.params.setMaxCategories(val); }
getMaxDepth() const448 inline int getMaxDepth() const CV_OVERRIDE { return impl.params.getMaxDepth(); }
setMaxDepth(int val)449 inline void setMaxDepth(int val) CV_OVERRIDE { impl.params.setMaxDepth(val); }
getMinSampleCount() const450 inline int getMinSampleCount() const CV_OVERRIDE { return impl.params.getMinSampleCount(); }
setMinSampleCount(int val)451 inline void setMinSampleCount(int val) CV_OVERRIDE { impl.params.setMinSampleCount(val); }
getCVFolds() const452 inline int getCVFolds() const CV_OVERRIDE { return impl.params.getCVFolds(); }
setCVFolds(int val)453 inline void setCVFolds(int val) CV_OVERRIDE { impl.params.setCVFolds(val); }
getUseSurrogates() const454 inline bool getUseSurrogates() const CV_OVERRIDE { return impl.params.getUseSurrogates(); }
setUseSurrogates(bool val)455 inline void setUseSurrogates(bool val) CV_OVERRIDE { impl.params.setUseSurrogates(val); }
getUse1SERule() const456 inline bool getUse1SERule() const CV_OVERRIDE { return impl.params.getUse1SERule(); }
setUse1SERule(bool val)457 inline void setUse1SERule(bool val) CV_OVERRIDE { impl.params.setUse1SERule(val); }
getTruncatePrunedTree() const458 inline bool getTruncatePrunedTree() const CV_OVERRIDE { return impl.params.getTruncatePrunedTree(); }
setTruncatePrunedTree(bool val)459 inline void setTruncatePrunedTree(bool val) CV_OVERRIDE { impl.params.setTruncatePrunedTree(val); }
getRegressionAccuracy() const460 inline float getRegressionAccuracy() const CV_OVERRIDE { return impl.params.getRegressionAccuracy(); }
setRegressionAccuracy(float val)461 inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); }
getPriors() const462 inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); }
setPriors(const cv::Mat & val)463 inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); }
getVotes(InputArray input,OutputArray output,int flags) const464 inline void getVotes(InputArray input, OutputArray output, int flags) const CV_OVERRIDE {return impl.getVotes(input,output,flags);}
465
RTreesImpl()466 RTreesImpl() {}
~RTreesImpl()467 virtual ~RTreesImpl() CV_OVERRIDE {}
468
getDefaultName() const469 String getDefaultName() const CV_OVERRIDE { return "opencv_ml_rtrees"; }
470
train(const Ptr<TrainData> & trainData,int flags)471 bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
472 {
473 CV_TRACE_FUNCTION();
474 CV_Assert(!trainData.empty());
475 if (impl.getCVFolds() != 0)
476 CV_Error(Error::StsBadArg, "Cross validation for RTrees is not implemented");
477 return impl.train(trainData, flags);
478 }
479
predict(InputArray samples,OutputArray results,int flags) const480 float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
481 {
482 CV_TRACE_FUNCTION();
483 CV_CheckEQ(samples.cols(), getVarCount(), "");
484 return impl.predict(samples, results, flags);
485 }
486
write(FileStorage & fs) const487 void write( FileStorage& fs ) const CV_OVERRIDE
488 {
489 CV_TRACE_FUNCTION();
490 impl.write(fs);
491 }
492
read(const FileNode & fn)493 void read( const FileNode& fn ) CV_OVERRIDE
494 {
495 CV_TRACE_FUNCTION();
496 impl.read(fn);
497 }
498
getVarImportance() const499 Mat getVarImportance() const CV_OVERRIDE { return Mat_<float>(impl.varImportance, true); }
getVarCount() const500 int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); }
501
isTrained() const502 bool isTrained() const CV_OVERRIDE { return impl.isTrained(); }
isClassifier() const503 bool isClassifier() const CV_OVERRIDE { return impl.isClassifier(); }
504
getRoots() const505 const vector<int>& getRoots() const CV_OVERRIDE { return impl.getRoots(); }
getNodes() const506 const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
getSplits() const507 const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
getSubsets() const508 const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
getOOBError() const509 double getOOBError() const CV_OVERRIDE { return impl.getOOBError(); }
510
511
512 DTreesImplForRTrees impl;
513 };
514
515
create()516 Ptr<RTrees> RTrees::create()
517 {
518 CV_TRACE_FUNCTION();
519 return makePtr<RTreesImpl>();
520 }
521
522 //Function needed for Python and Java wrappers
load(const String & filepath,const String & nodeName)523 Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
524 {
525 CV_TRACE_FUNCTION();
526 return Algorithm::load<RTrees>(filepath, nodeName);
527 }
528
529 }}
530
531 // End of file.
532