/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
42 
43 #include "precomp.hpp"
44 namespace cv {
45 namespace ml {
46 
//////////////////////////////////////////////////////////////////////////////////////////
//                                  Random trees                                        //
//////////////////////////////////////////////////////////////////////////////////////////
RTreeParams()50 RTreeParams::RTreeParams()
51 {
52     CV_TRACE_FUNCTION();
53     calcVarImportance = false;
54     nactiveVars = 0;
55     termCrit = TermCriteria(TermCriteria::EPS + TermCriteria::COUNT, 50, 0.1);
56 }
57 
RTreeParams(bool _calcVarImportance,int _nactiveVars,TermCriteria _termCrit)58 RTreeParams::RTreeParams(bool _calcVarImportance,
59                          int _nactiveVars,
60                          TermCriteria _termCrit )
61 {
62     CV_TRACE_FUNCTION();
63     calcVarImportance = _calcVarImportance;
64     nactiveVars = _nactiveVars;
65     termCrit = _termCrit;
66 }
67 
68 
69 class DTreesImplForRTrees CV_FINAL : public DTreesImpl
70 {
71 public:
DTreesImplForRTrees()72     DTreesImplForRTrees()
73     {
74         CV_TRACE_FUNCTION();
75         params.setMaxDepth(5);
76         params.setMinSampleCount(10);
77         params.setRegressionAccuracy(0.f);
78         params.useSurrogates = false;
79         params.setMaxCategories(10);
80         params.setCVFolds(0);
81         params.use1SERule = false;
82         params.truncatePrunedTree = false;
83         params.priors = Mat();
84         oobError = 0;
85     }
~DTreesImplForRTrees()86     virtual ~DTreesImplForRTrees() {}
87 
clear()88     void clear() CV_OVERRIDE
89     {
90         CV_TRACE_FUNCTION();
91         DTreesImpl::clear();
92         oobError = 0.;
93     }
94 
getActiveVars()95     const vector<int>& getActiveVars() CV_OVERRIDE
96     {
97         CV_TRACE_FUNCTION();
98         RNG &rng = theRNG();
99         int i, nvars = (int)allVars.size(), m = (int)activeVars.size();
100         for( i = 0; i < nvars; i++ )
101         {
102             int i1 = rng.uniform(0, nvars);
103             int i2 = rng.uniform(0, nvars);
104             std::swap(allVars[i1], allVars[i2]);
105         }
106         for( i = 0; i < m; i++ )
107             activeVars[i] = allVars[i];
108         return activeVars;
109     }
110 
startTraining(const Ptr<TrainData> & trainData,int flags)111     void startTraining( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
112     {
113         CV_TRACE_FUNCTION();
114         CV_Assert(!trainData.empty());
115         DTreesImpl::startTraining(trainData, flags);
116         int nvars = w->data->getNVars();
117         int i, m = rparams.nactiveVars > 0 ? rparams.nactiveVars : cvRound(std::sqrt((double)nvars));
118         m = std::min(std::max(m, 1), nvars);
119         allVars.resize(nvars);
120         activeVars.resize(m);
121         for( i = 0; i < nvars; i++ )
122             allVars[i] = varIdx[i];
123     }
124 
endTraining()125     void endTraining() CV_OVERRIDE
126     {
127         CV_TRACE_FUNCTION();
128         DTreesImpl::endTraining();
129         vector<int> a, b;
130         std::swap(allVars, a);
131         std::swap(activeVars, b);
132     }
133 
train(const Ptr<TrainData> & trainData,int flags)134     bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
135     {
136         CV_TRACE_FUNCTION();
137         RNG &rng = theRNG();
138         CV_Assert(!trainData.empty());
139         startTraining(trainData, flags);
140         int treeidx, ntrees = (rparams.termCrit.type & TermCriteria::COUNT) != 0 ?
141             rparams.termCrit.maxCount : 10000;
142         int i, j, k, vi, vi_, n = (int)w->sidx.size();
143         int nclasses = (int)classLabels.size();
144         double eps = (rparams.termCrit.type & TermCriteria::EPS) != 0 &&
145             rparams.termCrit.epsilon > 0 ? rparams.termCrit.epsilon : 0.;
146         vector<int> sidx(n);
147         vector<uchar> oobmask(n);
148         vector<int> oobidx;
149         vector<int> oobperm;
150         vector<double> oobres(n, 0.);
151         vector<int> oobcount(n, 0);
152         vector<int> oobvotes(n*nclasses, 0);
153         int nvars = w->data->getNVars();
154         int nallvars = w->data->getNAllVars();
155         const int* vidx = !varIdx.empty() ? &varIdx[0] : 0;
156         vector<float> samplebuf(nallvars);
157         Mat samples = w->data->getSamples();
158         float* psamples = samples.ptr<float>();
159         size_t sstep0 = samples.step1(), sstep1 = 1;
160         Mat sample0, sample(nallvars, 1, CV_32F, &samplebuf[0]);
161         int predictFlags = _isClassifier ? (PREDICT_MAX_VOTE + RAW_OUTPUT) : PREDICT_SUM;
162 
163         bool calcOOBError = eps > 0 || rparams.calcVarImportance;
164         double max_response = 0.;
165 
166         if( w->data->getLayout() == COL_SAMPLE )
167             std::swap(sstep0, sstep1);
168 
169         if( !_isClassifier )
170         {
171             for( i = 0; i < n; i++ )
172             {
173                 double val = std::abs(w->ord_responses[w->sidx[i]]);
174                 max_response = std::max(max_response, val);
175             }
176             CV_Assert(fabs(max_response) > 0);
177         }
178 
179         if( rparams.calcVarImportance )
180             varImportance.resize(nallvars, 0.f);
181 
182         for( treeidx = 0; treeidx < ntrees; treeidx++ )
183         {
184             for( i = 0; i < n; i++ )
185                 oobmask[i] = (uchar)1;
186 
187             for( i = 0; i < n; i++ )
188             {
189                 j = rng.uniform(0, n);
190                 sidx[i] = w->sidx[j];
191                 oobmask[j] = (uchar)0;
192             }
193             int root = addTree( sidx );
194             if( root < 0 )
195                 return false;
196 
197             if( calcOOBError )
198             {
199                 oobidx.clear();
200                 for( i = 0; i < n; i++ )
201                 {
202                     if( oobmask[i] )
203                         oobidx.push_back(i);
204                 }
205                 int n_oob = (int)oobidx.size();
206                 // if there is no out-of-bag samples, we can not compute OOB error
207                 // nor update the variable importance vector; so we proceed to the next tree
208                 if( n_oob == 0 )
209                     continue;
210                 double ncorrect_responses = 0.;
211 
212                 oobError = 0.;
213                 for( i = 0; i < n_oob; i++ )
214                 {
215                     j = oobidx[i];
216                     sample = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
217 
218                     double val = predictTrees(Range(treeidx, treeidx+1), sample, predictFlags);
219                     double sample_weight = w->sample_weights[w->sidx[j]];
220                     if( !_isClassifier )
221                     {
222                         oobres[j] += val;
223                         oobcount[j]++;
224                         double true_val = w->ord_responses[w->sidx[j]];
225                         double a = oobres[j]/oobcount[j] - true_val;
226                         oobError += sample_weight * a*a;
227                         val = (val - true_val)/max_response;
228                         ncorrect_responses += std::exp( -val*val );
229                     }
230                     else
231                     {
232                         int ival = cvRound(val);
233                         //Voting scheme to combine OOB errors of each tree
234                         int* votes = &oobvotes[j*nclasses];
235                         votes[ival]++;
236                         int best_class = 0;
237                         for( k = 1; k < nclasses; k++ )
238                             if( votes[best_class] < votes[k] )
239                                 best_class = k;
240                         int diff = best_class != w->cat_responses[w->sidx[j]];
241                         oobError += sample_weight * diff;
242                         ncorrect_responses += diff == 0;
243                     }
244                 }
245 
246                 oobError /= n_oob;
247                 if( rparams.calcVarImportance && n_oob > 1 )
248                 {
249                     Mat sample_clone;
250                     oobperm.resize(n_oob);
251                     for( i = 0; i < n_oob; i++ )
252                         oobperm[i] = oobidx[i];
253                     for (i = n_oob - 1; i > 0; --i)  //Randomly shuffle indices so we can permute features
254                     {
255                         int r_i = rng.uniform(0, n_oob);
256                         std::swap(oobperm[i], oobperm[r_i]);
257                     }
258 
259                     for( vi_ = 0; vi_ < nvars; vi_++ )
260                     {
261                         vi = vidx ? vidx[vi_] : vi_; //Ensure that only the user specified predictors are used for training
262                         double ncorrect_responses_permuted = 0;
263 
264                         for( i = 0; i < n_oob; i++ )
265                         {
266                             j = oobidx[i];
267                             int vj = oobperm[i];
268                             sample0 = Mat( nallvars, 1, CV_32F, psamples + sstep0*w->sidx[j], sstep1*sizeof(psamples[0]) );
269                             sample0.copyTo(sample_clone); //create a copy so we don't mess up the original data
270                             sample_clone.at<float>(vi) = psamples[sstep0*w->sidx[vj] + sstep1*vi];
271 
272                             double val = predictTrees(Range(treeidx, treeidx+1), sample_clone, predictFlags);
273                             if( !_isClassifier )
274                             {
275                                 val = (val - w->ord_responses[w->sidx[j]])/max_response;
276                                 ncorrect_responses_permuted += exp( -val*val );
277                             }
278                             else
279                             {
280                                 ncorrect_responses_permuted += cvRound(val) == w->cat_responses[w->sidx[j]];
281                             }
282                         }
283                         varImportance[vi] += (float)(ncorrect_responses - ncorrect_responses_permuted);
284                     }
285                 }
286             }
287             if( calcOOBError && oobError < eps )
288                 break;
289         }
290 
291         if( rparams.calcVarImportance )
292         {
293             for( vi_ = 0; vi_ < nallvars; vi_++ )
294                 varImportance[vi_] = std::max(varImportance[vi_], 0.f);
295             normalize(varImportance, varImportance, 1., 0, NORM_L1);
296         }
297         endTraining();
298         return true;
299     }
300 
writeTrainingParams(FileStorage & fs) const301     void writeTrainingParams( FileStorage& fs ) const CV_OVERRIDE
302     {
303         CV_TRACE_FUNCTION();
304         DTreesImpl::writeTrainingParams(fs);
305         fs << "nactive_vars" << rparams.nactiveVars;
306     }
307 
write(FileStorage & fs) const308     void write( FileStorage& fs ) const CV_OVERRIDE
309     {
310         CV_TRACE_FUNCTION();
311         if( roots.empty() )
312             CV_Error( CV_StsBadArg, "RTrees have not been trained" );
313 
314         writeFormat(fs);
315         writeParams(fs);
316 
317         fs << "oob_error" << oobError;
318         if( !varImportance.empty() )
319             fs << "var_importance" << varImportance;
320 
321         int k, ntrees = (int)roots.size();
322 
323         fs << "ntrees" << ntrees
324            << "trees" << "[";
325 
326         for( k = 0; k < ntrees; k++ )
327         {
328             fs << "{";
329             writeTree(fs, roots[k]);
330             fs << "}";
331         }
332 
333         fs << "]";
334     }
335 
readParams(const FileNode & fn)336     void readParams( const FileNode& fn ) CV_OVERRIDE
337     {
338         CV_TRACE_FUNCTION();
339         DTreesImpl::readParams(fn);
340 
341         FileNode tparams_node = fn["training_params"];
342         rparams.nactiveVars = (int)tparams_node["nactive_vars"];
343     }
344 
read(const FileNode & fn)345     void read( const FileNode& fn ) CV_OVERRIDE
346     {
347         CV_TRACE_FUNCTION();
348         clear();
349 
350         //int nclasses = (int)fn["nclasses"];
351         //int nsamples = (int)fn["nsamples"];
352         oobError = (double)fn["oob_error"];
353         int ntrees = (int)fn["ntrees"];
354 
355         readVectorOrMat(fn["var_importance"], varImportance);
356 
357         readParams(fn);
358 
359         FileNode trees_node = fn["trees"];
360         FileNodeIterator it = trees_node.begin();
361         CV_Assert( ntrees == (int)trees_node.size() );
362 
363         for( int treeidx = 0; treeidx < ntrees; treeidx++, ++it )
364         {
365             FileNode nfn = (*it)["nodes"];
366             readTree(nfn);
367         }
368     }
369 
getVotes(InputArray input,OutputArray output,int flags) const370     void getVotes( InputArray input, OutputArray output, int flags ) const
371     {
372         CV_TRACE_FUNCTION();
373         CV_Assert( !roots.empty() );
374         int nclasses = (int)classLabels.size(), ntrees = (int)roots.size();
375         Mat samples = input.getMat(), results;
376         int i, j, nsamples = samples.rows;
377 
378         int predictType = flags & PREDICT_MASK;
379         if( predictType == PREDICT_AUTO )
380         {
381             predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
382                 PREDICT_SUM : PREDICT_MAX_VOTE;
383         }
384 
385         if( predictType == PREDICT_SUM )
386         {
387             output.create(nsamples, ntrees, CV_32F);
388             results = output.getMat();
389             for( i = 0; i < nsamples; i++ )
390             {
391                 for( j = 0; j < ntrees; j++ )
392                 {
393                     float val = predictTrees( Range(j, j+1), samples.row(i), flags);
394                     results.at<float> (i, j) = val;
395                 }
396             }
397         } else
398         {
399             vector<int> votes;
400             output.create(nsamples+1, nclasses, CV_32S);
401             results = output.getMat();
402 
403             for ( j = 0; j < nclasses; j++)
404             {
405                 results.at<int> (0, j) = classLabels[j];
406             }
407 
408             for( i = 0; i < nsamples; i++ )
409             {
410                 votes.clear();
411                 for( j = 0; j < ntrees; j++ )
412                 {
413                     int val = (int)predictTrees( Range(j, j+1), samples.row(i), flags);
414                     votes.push_back(val);
415                 }
416 
417                 for ( j = 0; j < nclasses; j++)
418                 {
419                     results.at<int> (i+1, j) = (int)std::count(votes.begin(), votes.end(), classLabels[j]);
420                 }
421             }
422         }
423     }
424 
getOOBError() const425     double getOOBError() const {
426         return oobError;
427     }
428 
429     RTreeParams rparams;
430     double oobError;
431     vector<float> varImportance;
432     vector<int> allVars, activeVars;
433 };
434 
435 
436 class RTreesImpl CV_FINAL : public RTrees
437 {
438 public:
getCalculateVarImportance() const439     inline bool getCalculateVarImportance() const CV_OVERRIDE { return impl.rparams.calcVarImportance; }
setCalculateVarImportance(bool val)440     inline void setCalculateVarImportance(bool val) CV_OVERRIDE { impl.rparams.calcVarImportance = val; }
getActiveVarCount() const441     inline int getActiveVarCount() const CV_OVERRIDE { return impl.rparams.nactiveVars; }
setActiveVarCount(int val)442     inline void setActiveVarCount(int val) CV_OVERRIDE { impl.rparams.nactiveVars = val; }
getTermCriteria() const443     inline TermCriteria getTermCriteria() const CV_OVERRIDE { return impl.rparams.termCrit; }
setTermCriteria(const TermCriteria & val)444     inline void setTermCriteria(const TermCriteria& val) CV_OVERRIDE { impl.rparams.termCrit = val; }
445 
getMaxCategories() const446     inline int getMaxCategories() const CV_OVERRIDE { return impl.params.getMaxCategories(); }
setMaxCategories(int val)447     inline void setMaxCategories(int val) CV_OVERRIDE { impl.params.setMaxCategories(val); }
getMaxDepth() const448     inline int getMaxDepth() const CV_OVERRIDE { return impl.params.getMaxDepth(); }
setMaxDepth(int val)449     inline void setMaxDepth(int val) CV_OVERRIDE { impl.params.setMaxDepth(val); }
getMinSampleCount() const450     inline int getMinSampleCount() const CV_OVERRIDE { return impl.params.getMinSampleCount(); }
setMinSampleCount(int val)451     inline void setMinSampleCount(int val) CV_OVERRIDE { impl.params.setMinSampleCount(val); }
getCVFolds() const452     inline int getCVFolds() const CV_OVERRIDE { return impl.params.getCVFolds(); }
setCVFolds(int val)453     inline void setCVFolds(int val) CV_OVERRIDE { impl.params.setCVFolds(val); }
getUseSurrogates() const454     inline bool getUseSurrogates() const CV_OVERRIDE { return impl.params.getUseSurrogates(); }
setUseSurrogates(bool val)455     inline void setUseSurrogates(bool val) CV_OVERRIDE { impl.params.setUseSurrogates(val); }
getUse1SERule() const456     inline bool getUse1SERule() const CV_OVERRIDE { return impl.params.getUse1SERule(); }
setUse1SERule(bool val)457     inline void setUse1SERule(bool val) CV_OVERRIDE { impl.params.setUse1SERule(val); }
getTruncatePrunedTree() const458     inline bool getTruncatePrunedTree() const CV_OVERRIDE { return impl.params.getTruncatePrunedTree(); }
setTruncatePrunedTree(bool val)459     inline void setTruncatePrunedTree(bool val) CV_OVERRIDE { impl.params.setTruncatePrunedTree(val); }
getRegressionAccuracy() const460     inline float getRegressionAccuracy() const CV_OVERRIDE { return impl.params.getRegressionAccuracy(); }
setRegressionAccuracy(float val)461     inline void setRegressionAccuracy(float val) CV_OVERRIDE { impl.params.setRegressionAccuracy(val); }
getPriors() const462     inline cv::Mat getPriors() const CV_OVERRIDE { return impl.params.getPriors(); }
setPriors(const cv::Mat & val)463     inline void setPriors(const cv::Mat& val) CV_OVERRIDE { impl.params.setPriors(val); }
getVotes(InputArray input,OutputArray output,int flags) const464     inline void getVotes(InputArray input, OutputArray output, int flags) const CV_OVERRIDE {return impl.getVotes(input,output,flags);}
465 
RTreesImpl()466     RTreesImpl() {}
~RTreesImpl()467     virtual ~RTreesImpl() CV_OVERRIDE {}
468 
getDefaultName() const469     String getDefaultName() const CV_OVERRIDE { return "opencv_ml_rtrees"; }
470 
train(const Ptr<TrainData> & trainData,int flags)471     bool train( const Ptr<TrainData>& trainData, int flags ) CV_OVERRIDE
472     {
473         CV_TRACE_FUNCTION();
474         CV_Assert(!trainData.empty());
475         if (impl.getCVFolds() != 0)
476             CV_Error(Error::StsBadArg, "Cross validation for RTrees is not implemented");
477         return impl.train(trainData, flags);
478     }
479 
predict(InputArray samples,OutputArray results,int flags) const480     float predict( InputArray samples, OutputArray results, int flags ) const CV_OVERRIDE
481     {
482         CV_TRACE_FUNCTION();
483         CV_CheckEQ(samples.cols(), getVarCount(), "");
484         return impl.predict(samples, results, flags);
485     }
486 
write(FileStorage & fs) const487     void write( FileStorage& fs ) const CV_OVERRIDE
488     {
489         CV_TRACE_FUNCTION();
490         impl.write(fs);
491     }
492 
read(const FileNode & fn)493     void read( const FileNode& fn ) CV_OVERRIDE
494     {
495         CV_TRACE_FUNCTION();
496         impl.read(fn);
497     }
498 
getVarImportance() const499     Mat getVarImportance() const CV_OVERRIDE { return Mat_<float>(impl.varImportance, true); }
getVarCount() const500     int getVarCount() const CV_OVERRIDE { return impl.getVarCount(); }
501 
isTrained() const502     bool isTrained() const CV_OVERRIDE { return impl.isTrained(); }
isClassifier() const503     bool isClassifier() const CV_OVERRIDE { return impl.isClassifier(); }
504 
getRoots() const505     const vector<int>& getRoots() const CV_OVERRIDE { return impl.getRoots(); }
getNodes() const506     const vector<Node>& getNodes() const CV_OVERRIDE { return impl.getNodes(); }
getSplits() const507     const vector<Split>& getSplits() const CV_OVERRIDE { return impl.getSplits(); }
getSubsets() const508     const vector<int>& getSubsets() const CV_OVERRIDE { return impl.getSubsets(); }
getOOBError() const509     double getOOBError() const CV_OVERRIDE { return impl.getOOBError(); }
510 
511 
512     DTreesImplForRTrees impl;
513 };
514 
515 
create()516 Ptr<RTrees> RTrees::create()
517 {
518     CV_TRACE_FUNCTION();
519     return makePtr<RTreesImpl>();
520 }
521 
// Function needed for the Python and Java wrappers
load(const String & filepath,const String & nodeName)523 Ptr<RTrees> RTrees::load(const String& filepath, const String& nodeName)
524 {
525     CV_TRACE_FUNCTION();
526     return Algorithm::load<RTrees>(filepath, nodeName);
527 }
528 
529 }}
530 
// End of file.
532