/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include <ctype.h>

#include <opencv2/core/utils/logger.hpp>

namespace cv {
namespace ml {

using std::vector;

TreeParams::TreeParams()
{
    maxDepth = INT_MAX;
    minSampleCount = 10;
    regressionAccuracy = 0.01f;
    useSurrogates = false;
    maxCategories = 10;
    CVFolds = 10;
    use1SERule = true;
    truncatePrunedTree = true;
    priors = Mat();
}

TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
                       double _regressionAccuracy, bool _useSurrogates,
                       int _maxCategories, int _CVFolds,
                       bool _use1SERule, bool _truncatePrunedTree,
                       const Mat& _priors)
{
    maxDepth = _maxDepth;
    minSampleCount = _minSampleCount;
    regressionAccuracy = (float)_regressionAccuracy;
    useSurrogates = _useSurrogates;
    maxCategories = _maxCategories;
    CVFolds = _CVFolds;
    use1SERule = _use1SERule;
    truncatePrunedTree = _truncatePrunedTree;
    priors = _priors;
}

DTrees::Node::Node()
{
    classIdx = 0;
    value = 0;
    parent = left = right = split = defaultDir = -1;
}

DTrees::Split::Split()
{
    varIdx = 0;
    inversed = false;
    quality = 0.f;
    next = -1;
    c = 0.f;
    subsetOfs = 0;
}


DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
{
    CV_Assert(!_data.empty());
    data = _data;
    vector<int> subsampleIdx;
    Mat sidx0 = _data->getTrainSampleIdx();
    if( !sidx0.empty() )
    {
        sidx0.copyTo(sidx);
        std::sort(sidx.begin(), sidx.end());
    }
    else
    {
        int n = _data->getNSamples();
        setRangeVector(sidx, n);
    }

    maxSubsetSize = 0;
}

DTreesImpl::DTreesImpl() : _isClassifier(false) {}
DTreesImpl::~DTreesImpl() {}
void DTreesImpl::clear()
{
    varIdx.clear();
    compVarIdx.clear();
    varType.clear();
    catOfs.clear();
    catMap.clear();
    roots.clear();
    nodes.clear();
    splits.clear();
    subsets.clear();
    classLabels.clear();

    w.release();
    _isClassifier = false;
}

void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    CV_Assert(!data.empty());
    clear();
    w = makePtr<WorkData>(data);

    Mat vtype = data->getVarType();
    vtype.copyTo(varType);

    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);

    int nallvars = data->getNAllVars();

    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        setRangeVector(varIdx, nallvars);

    initCompVarIdx();

    w->maxSubsetSize = 0;

    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));

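    // a categorical subset is stored as a bit mask packed into 32-bit ints
    // (see subset[i >> 5] |= 1 << (i & 31) in the split functions below), so
    // convert the largest category count into the number of 32-bit words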
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);

    data->getSampleWeights().copyTo(w->sample_weights);

    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;

    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();

        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );

            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );

            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}


void DTreesImpl::initCompVarIdx()
{
    int nallvars = (int)varType.size();
    compVarIdx.assign(nallvars, -1);
    int i, nvars = (int)varIdx.size(), prevIdx = -1;
    for( i = 0; i < nvars; i++ )
    {
        int vi = varIdx[i];
        CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
        prevIdx = vi;
        compVarIdx[vi] = i;
    }
}

void DTreesImpl::endTraining()
{
    w.release();
}

bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
{
    CV_Assert(!trainData.empty());
    startTraining(trainData, flags);
    bool ok = addTree( w->sidx ) >= 0;
    w.release();
    endTraining();
    return ok;
}

const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}

int DTreesImpl::addTree(const vector<int>& sidx )
{
    size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();

    w->wnodes.reserve(n);
    w->wsplits.reserve(n);
    w->wsubsets.reserve(n*w->maxSubsetSize);
    w->wnodes.clear();
    w->wsplits.clear();
    w->wsubsets.clear();

    int cv_n = params.getCVFolds();

    if( cv_n > 0 )
    {
        w->cv_Tn.resize(n*cv_n);
        w->cv_node_error.resize(n*cv_n);
        w->cv_node_risk.resize(n*cv_n);
    }

    // build the tree recursively
    int w_root = addNodeAndTrySplit(-1, sidx);
    int maxdepth = INT_MAX;//pruneCV(root);

    int w_nidx = w_root, pidx = -1, depth = 0;
    int root = (int)nodes.size();

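    // copy the grown working tree (WNode/WSplit) into the final compact
    // representation (Node/Split), visiting the nodes in depth-first order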
    for(;;)
    {
        const WNode& wnode = w->wnodes[w_nidx];
        Node node;
        node.parent = pidx;
        node.classIdx = wnode.class_idx;
        node.value = wnode.value;
        node.defaultDir = wnode.defaultDir;

        int wsplit_idx = wnode.split;
        if( wsplit_idx >= 0 )
        {
            const WSplit& wsplit = w->wsplits[wsplit_idx];
            Split split;
            split.c = wsplit.c;
            split.quality = wsplit.quality;
            split.inversed = wsplit.inversed;
            split.varIdx = wsplit.varIdx;
            split.subsetOfs = -1;
            if( wsplit.subsetOfs >= 0 )
            {
                int ssize = getSubsetSize(split.varIdx);
                split.subsetOfs = (int)subsets.size();
                subsets.resize(split.subsetOfs + ssize);
                // Copy the subset only when it is non-empty: if ssize == 0 the
                // resize above is a no-op, so indexing subsets[split.subsetOfs]
                // would be out of range. The check keeps the memory access safe
                // and also skips a useless memcpy call with a zero size.
                if(ssize > 0)
                {
                    memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
                }
            }
            node.split = (int)splits.size();
            splits.push_back(split);
        }
        int nidx = (int)nodes.size();
        nodes.push_back(node);
        if( pidx >= 0 )
        {
            int w_pidx = w->wnodes[w_nidx].parent;
            if( w->wnodes[w_pidx].left == w_nidx )
            {
                nodes[pidx].left = nidx;
            }
            else
            {
                CV_Assert(w->wnodes[w_pidx].right == w_nidx);
                nodes[pidx].right = nidx;
            }
        }

        if( wnode.left >= 0 && depth+1 < maxdepth )
        {
            w_nidx = wnode.left;
            pidx = nidx;
            depth++;
        }
        else
        {
            int w_pidx = wnode.parent;
            while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
            {
                w_nidx = w_pidx;
                w_pidx = w->wnodes[w_pidx].parent;
                nidx = pidx;
                pidx = nodes[pidx].parent;
                depth--;
            }

            if( w_pidx < 0 )
                break;

            w_nidx = w->wnodes[w_pidx].right;
            CV_Assert( w_nidx >= 0 );
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}

int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    WNode& node = w->wnodes.back();

    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();

    if( nfolds > 0 )
    {
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;

    calcValue( nidx, sidx );

    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }

    if( can_split )
        node.split = findBestSplit( sidx );

    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);

    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");

        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}

int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf.data(), *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;

    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }

    return splitidx;
}

void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        double* cls_count = buf.data();
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            double *cv_sum = buf.data(), *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                double s = sum - cv_sum[j], si = sum - s;
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }
        CV_Assert(fabs(sumw) > 0);
        node->node_risk = sum2 - (sum/sumw)*sum;
        node->node_risk /= sumw;
        node->value = sum/sumw;
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)buf.data();
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }
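    // scan the candidate split positions in sorted order; the quality of a
    // split, (lsum2*R + rsum2*L)/(L*R) = lsum2/L + rsum2/R with
    // lsum2 = sum_k lcw[k]^2 and rsum2 = sum_k rcw[k]^2, is maximal exactly
    // when the total weighted Gini impurity of the two branches is minimal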

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
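// It is used by findSplitCatClass() when a categorical variable has more than
// maxCategories distinct values: the rows of the per-category class-count
// matrix are clustered, reducing the number of "categories" that the subset
// enumeration has to consider.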
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf.data(), *c_weights = buf.data() + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}

DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = buf.data();
    double* rc = lc + m;
    double* _cjk = rc + m*2, *cjk = _cjk;
    double* c_weights = cjk + m*mi;

    int* labels = (int*)(buf.data() + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        if( mi > params.getMaxCategories() )
        {
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;
            cluster_labels = (int*)(cjk + m*mi);
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

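    // enumerate candidate binary partitions of the categories: for m > 2
    // classes all 2^mi subsets are visited in Gray-code order, so each step
    // moves exactly one category between the two sides; for m == 2 it is
    // enough to scan prefixes of the categories sorted by their class-1 weight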
    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;
        else
        {
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
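            // (diff has a single bit set; converting it to float and reading
            //  the biased exponent, (u.i >> 23) - 127, yields log2 of that
            //  bit without a loop)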
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
        if( m == 2 )
        {
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)buf.data();
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
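    // (the quality (lsum*lsum*R + rsum*rsum*L)/(L*R) = lsum^2/L + rsum^2/R
    //  is, up to the constant sum of squared responses, the weighted
    //  variance reduction achieved by the split, so maximizing it minimizes
    //  the total squared error of the two branches)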
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}

DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = buf.data() + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert back to unnormalized sums
    // (there should be very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

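    // with the categories sorted by mean response, the optimal binary
    // partition is a prefix of that order (the classic CART result for
    // regression), so only mi-1 candidate splits have to be evaluated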
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}

int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
                         vector<int>& _sleft, vector<int>& _sright )
{
    WSplit split = w->wsplits[splitidx];
    int i, si, n = (int)_sidx.size(), vi = split.varIdx;
    _sleft.reserve(n);
    _sright.reserve(n);
    _sleft.clear();
    _sright.clear();

    AutoBuffer<float> buf(n);
    int mi = getCatCount(vi);
    double wleft = 0, wright = 0;
    const double* weights = &w->sample_weights[0];

    if( mi <= 0 ) // split on an ordered variable
    {
        float c = split.c;
        float* values = buf.data();
        w->data->getValues(vi, _sidx, values);

        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            if( values[i] <= c )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    else
    {
        const int* subset = &w->wsubsets[split.subsetOfs];
        int* cat_labels = (int*)buf.data();
        w->data->getNormCatValues(vi, _sidx, cat_labels);

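        // CV_DTREE_CAT_DIR() maps the bit of the packed subset mask that
        // corresponds to category u to a direction; a negative value sends
        // the sample to the left branch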
        for( i = 0; i < n; i++ )
        {
            si = _sidx[i];
            unsigned u = cat_labels[i];
            if( CV_DTREE_CAT_DIR(u, subset) < 0 )
            {
                _sleft.push_back(si);
                wleft += weights[si];
            }
            else
            {
                _sright.push_back(si);
                wright += weights[si];
            }
        }
    }
    CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
    return wleft > wright ? -1 : 1;
}

int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if needed, apply the 1SE rule).
    // 3. store the best index and cut the branches.

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;

        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err_jk.at<double>(j, ti);
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    return min_idx;
}

double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    for(;;)
    {
        WNode *node = 0, *parent = 0;

        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

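            // cost-complexity measure: alpha is the increase in risk per
            // removed leaf if the subtree rooted at 'parent' were collapsed
            // into a single leaf, (R(node) - R(subtree))/(#leaves - 1)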
            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )
            break;

        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}

bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    if( node->left < 0 )
        return true;

    for(;;)
    {
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )
                    return true;
                break;
            }
            nidx = node->left;
        }

        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;

        if( pidx < 0 )
            break;

        nidx = w->wnodes[pidx].right;
    }

    return false;
}

float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );

    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    int catbufsize = ncats > 0 ? nvars : 0;
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf.data();
    int* catbuf = votes + nclasses;
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();

    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;

    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;

        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of the input categorical variables is not an integer" );

                        CV_Assert(cmap != NULL);
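                        // binary search for ival among the sorted raw
                        // category values cmap[a..b) of this variable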
                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        CV_Assert( c >= 0 && ival == cmap[c] );

                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }

        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}


float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
{
    CV_Assert( !roots.empty() );
    Mat samples = _samples.getMat(), results;
    int i, nsamples = samples.rows;
    int rtype = CV_32F;
    bool needresults = _results.needed();
    float retval = 0.f;
    bool iscls = isClassifier();
    float scale = !iscls ? 1.f/(int)roots.size() : 1.f;

    if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
        rtype = CV_32S;

    if( needresults )
    {
        _results.create(nsamples, 1, rtype);
        results = _results.getMat();
    }
    else
        nsamples = std::min(nsamples, 1);

    for( i = 0; i < nsamples; i++ )
    {
        float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
        if( needresults )
        {
            if( rtype == CV_32F )
                results.at<float>(i) = val;
            else
                results.at<int>(i) = cvRound(val);
        }
        if( i == 0 )
            retval = val;
    }
    return retval;
}
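
/* A minimal usage sketch (hypothetical variable names; assumes 'data' is a
   filled Ptr<TrainData> and 'sampleRow' is a CV_32F row of features):

       Ptr<DTrees> dtree = DTrees::create();
       dtree->setMaxDepth(8);
       dtree->setCVFolds(0);     // disable built-in cross-validation pruning
       dtree->train(data);
       float response = dtree->predict(sampleRow);
*/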

void DTreesImpl::writeTrainingParams(FileStorage& fs) const
{
    fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
    fs << "max_categories" << params.getMaxCategories();
    fs << "regression_accuracy" << params.getRegressionAccuracy();

    fs << "max_depth" << params.getMaxDepth();
    fs << "min_sample_count" << params.getMinSampleCount();
    fs << "cross_validation_folds" << params.getCVFolds();

    if( params.getCVFolds() > 1 )
        fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);

    if( !params.priors.empty() )
        fs << "priors" << params.priors;
}

void DTreesImpl::writeParams(FileStorage& fs) const
{
    fs << "is_classifier" << isClassifier();
    fs << "var_all" << (int)varType.size();
    fs << "var_count" << getVarCount();

    int ord_var_count = 0, cat_var_count = 0;
    int i, n = (int)varType.size();
    for( i = 0; i < n; i++ )
        if( varType[i] == VAR_ORDERED )
            ord_var_count++;
        else
            cat_var_count++;
    fs << "ord_var_count" << ord_var_count;
    fs << "cat_var_count" << cat_var_count;

    fs << "training_params" << "{";
    writeTrainingParams(fs);

    fs << "}";

    if( !varIdx.empty() )
    {
        fs << "global_var_idx" << 1;
        fs << "var_idx" << varIdx;
    }

    fs << "var_type" << varType;

    if( !catOfs.empty() )
        fs << "cat_ofs" << catOfs;
    if( !catMap.empty() )
        fs << "cat_map" << catMap;
    if( !classLabels.empty() )
        fs << "class_labels" << classLabels;
    if( !missingSubst.empty() )
        fs << "missing_subst" << missingSubst;
}

void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";

    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;
        // ad-hoc rule for choosing the inverse categorical split notation,
        // which gives a more compact and readable representation
1614         int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;
1615 
1616         fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";
1617 
1618         for( i = 0; i < n; i++ )
1619         {
1620             int dir = CV_DTREE_CAT_DIR(i, subset);
1621             if( dir*default_dir < 0 )
1622                 fs << i;
1623         }
1624 
1625         fs << "]";
1626     }
1627     else
1628         fs << (!split.inversed ? "le" : "gt") << split.c;
1629 
1630     fs << "}";
1631 }
1632 
void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
{
    const Node& node = nodes[nidx];
    fs << "{";
    fs << "depth" << depth;
    fs << "value" << node.value;

    if( _isClassifier )
        fs << "norm_class_idx" << node.classIdx;

    if( node.split >= 0 )
    {
        fs << "splits" << "[";

        for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
            writeSplit( fs, splitidx );

        fs << "]";
    }

    fs << "}";
}

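// Saves the whole tree as a flat "nodes" list in depth-first (pre-order)
// order, using parent links instead of recursion or an explicit stack:
// descend along left children, then climb past fully written subtrees and
// continue with the first pending right child.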
void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;

        if( pidx < 0 )
            break;

        nidx = nodes[pidx].right;
    }

    fs << "]";
}

void DTreesImpl::write( FileStorage& fs ) const
{
    writeFormat(fs);
    writeParams(fs);
    writeTree(fs, roots[0]);
}

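// Restores the model descriptor written by writeParams(). Also detects
// models saved by OpenCV 2.4, and 3.x models affected by the missing
// "format" field export bug referenced below.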
void DTreesImpl::readParams( const FileNode& fn )
{
    _isClassifier = (int)fn["is_classifier"] != 0;
    int varAll = (int)fn["var_all"];
    int varCount = (int)fn["var_count"];
    /*int cat_var_count = (int)fn["cat_var_count"];
    int ord_var_count = (int)fn["ord_var_count"];*/

    if (varAll <= 0)
        CV_Error(Error::StsParseError, "The field \"var_all\" of DTree classifier is missing or non-positive");

    FileNode tparams_node = fn["training_params"];

    TreeParams params0 = TreeParams();

    if( !tparams_node.empty() ) // training parameters are optional
    {
        params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
        params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
        params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
        params0.setMaxDepth((int)tparams_node["max_depth"]);
        params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
        params0.setCVFolds((int)tparams_node["cross_validation_folds"]);

        if( params0.getCVFolds() > 1 )
        {
            // must be stored in params0, not params: setDParams(params0) below
            // would otherwise overwrite the value read from the file
            params0.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
        }

        tparams_node["priors"] >> params0.priors;
    }

    readVectorOrMat(fn["var_idx"], varIdx);
    fn["var_type"] >> varType;

    bool isLegacy = false;
    if (fn["format"].empty())  // export bug until OpenCV 3.2: https://github.com/opencv/opencv/pull/6314
    {
        if (!fn["cat_ofs"].empty())
            isLegacy = false;  // 2.4 doesn't store "cat_ofs"
        else if (!fn["missing_subst"].empty())
            isLegacy = false;  // 2.4 doesn't store "missing_subst"
        else if (!fn["class_labels"].empty())
            isLegacy = false;  // 2.4 doesn't store "class_labels"
        else if ((int)varType.size() != varAll)
            isLegacy = true;  // 3.0+: https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1576
        else if (/*(int)varType.size() == varAll &&*/ varCount == varAll)
            isLegacy = true;
        else
        {
            // 3.0+:
            // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1552-L1553
            // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/precomp.hpp#L296
            isLegacy = !(varCount + 1 == varAll);
        }
        CV_LOG_INFO(NULL, "ML/DTrees: possible missing 'format' field due to a bug in the OpenCV export implementation. "
                "Details: https://github.com/opencv/opencv/issues/5412. Consider re-exporting the saved ML model. "
                "isLegacy = " << isLegacy);
    }
    else
    {
        int format = 0;
        fn["format"] >> format;
        CV_CheckGT(format, 0, "Invalid 'format' field value");
        isLegacy = format < 3;
    }

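    // Legacy (2.4 / early 3.x) files store one type entry per used variable;
    // extend the type vector so it covers all variables plus the response,
    // letting the rest of the loader index it uniformly.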
    if (isLegacy && (int)varType.size() <= varAll)
    {
        std::vector<uchar> extendedTypes(varAll + 1, 0);

        int i = 0, n;
        if (!varIdx.empty())
        {
            n = (int)varIdx.size();
            for (; i < n; ++i)
            {
                int var = varIdx[i];
                extendedTypes[var] = varType[i];
            }
        }
        else
        {
            n = (int)varType.size();
            for (; i < n; ++i)
            {
                extendedTypes[i] = varType[i];
            }
        }
        extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
        extendedTypes.swap(varType);
    }

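    // In legacy files "cat_map" stores the category values of all categorical
    // variables back to back, with the class labels appended at the end;
    // the per-variable offsets ("cat_ofs") are reconstructed from "cat_count",
    // and the class labels are split off afterwards.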
    readVectorOrMat(fn["cat_map"], catMap);

    if (isLegacy)
    {
        // generating "catOfs" from "cat_count"
        catOfs.clear();
        classLabels.clear();
        std::vector<int> counts;
        readVectorOrMat(fn["cat_count"], counts);
        unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
        for (; i < size; ++i)
        {
            Vec2i newOffsets(0, 0);
            if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
            {
                newOffsets[0] = curShift;
                curShift += counts[j];
                newOffsets[1] = curShift;
                ++j;
            }
            catOfs.push_back(newOffsets);
        }
        // the remaining elements of "catMap" are the "classLabels"
        if (curShift < catMap.size())
        {
            classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
            catMap.erase(catMap.begin() + curShift, catMap.end());
        }
    }
    else
    {
        fn["cat_ofs"] >> catOfs;
        fn["missing_subst"] >> missingSubst;
        fn["class_labels"] >> classLabels;
    }

    // init var mapping for node reading (var indexes or varIdx indexes)
    bool globalVarIdx = false;
    fn["global_var_idx"] >> globalVarIdx;
    if (globalVarIdx || varIdx.empty())
        setRangeVector(varMapping, (int)varType.size());
    else
        varMapping = varIdx;

    initCompVarIdx();
    setDParams(params0);
}

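// Reads one split entry. For categorical splits the chosen category set is
// stored as a bitset packed into ints: category c maps to bit (c & 31) of
// element (c >> 5), so each int covers 32 category values.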
int DTreesImpl::readSplit( const FileNode& fn )
{
    Split split;

    int vi = (int)fn["var"];
    CV_Assert( 0 <= vi && vi < (int)varMapping.size() ); // vi indexes varMapping
    vi = varMapping[vi]; // convert to varIdx if needed
    split.varIdx = vi;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        int i, val, ssize = getSubsetSize(vi);
        split.subsetOfs = (int)subsets.size();
        for( i = 0; i < ssize; i++ )
            subsets.push_back(0);
        int* subset = &subsets[split.subsetOfs];
        FileNode fns = fn["in"];
        if( fns.empty() )
        {
            fns = fn["not_in"];
            split.inversed = true;
        }

        if( fns.isInt() )
        {
            val = (int)fns;
            subset[val >> 5] |= 1 << (val & 31);
        }
        else
        {
            FileNodeIterator it = fns.begin();
            int n = (int)fns.size();
            for( i = 0; i < n; i++, ++it )
            {
                val = (int)*it;
                subset[val >> 5] |= 1 << (val & 31);
            }
        }

        // for categorical splits we do not use inversed splits;
        // instead, we invert the category set stored in the split
        if( split.inversed )
        {
            for( i = 0; i < ssize; i++ )
                subset[i] ^= -1;
            split.inversed = false;
        }
    }
    else
    {
        FileNode cmpNode = fn["le"];
        if( cmpNode.empty() )
        {
            cmpNode = fn["gt"];
            split.inversed = true;
        }
        split.c = (float)cmpNode;
    }

    split.quality = (float)fn["quality"];
    splits.push_back(split);

    return (int)(splits.size() - 1);
}

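// Reads one node entry. The "splits" list, if present, is rebuilt as a
// chain: the first entry becomes the node's primary split and each further
// (surrogate) entry is linked in via Split::next.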
int DTreesImpl::readNode( const FileNode& fn )
{
    Node node;
    node.value = (double)fn["value"];

    if( _isClassifier )
        node.classIdx = (int)fn["norm_class_idx"];

    FileNode sfn = fn["splits"];
    if( !sfn.empty() )
    {
        int i, n = (int)sfn.size(), prevsplit = -1;
        FileNodeIterator it = sfn.begin();

        for( i = 0; i < n; i++, ++it )
        {
            int splitidx = readSplit(*it);
            if( splitidx < 0 )
                break;
            if( prevsplit < 0 )
                node.split = splitidx;
            else
                splits[prevsplit].next = splitidx;
            prevsplit = splitidx;
        }
    }
    nodes.push_back(node);
    return (int)(nodes.size() - 1);
}

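// Rebuilds the linked tree from the flat depth-first "nodes" list: a node
// with a split becomes the parent of the nodes that follow (left child
// first), while a leaf pops back up to the nearest ancestor that is still
// missing its right child.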
int DTreesImpl::readTree( const FileNode& fn )
{
    int i, n = (int)fn.size(), root = -1, pidx = -1;
    FileNodeIterator it = fn.begin();

    for( i = 0; i < n; i++, ++it )
    {
        int nidx = readNode(*it);
        if( nidx < 0 )
            break;
        Node& node = nodes[nidx];
        node.parent = pidx;
        if( pidx < 0 )
            root = nidx;
        else
        {
            Node& parent = nodes[pidx];
            if( parent.left < 0 )
                parent.left = nidx;
            else
                parent.right = nidx;
        }
        if( node.split >= 0 )
            pidx = nidx;
        else
        {
            while( pidx >= 0 && nodes[pidx].right >= 0 )
                pidx = nodes[pidx].parent;
        }
    }
    roots.push_back(root);
    return root;
}

void DTreesImpl::read( const FileNode& fn )
{
    clear();
    readParams(fn);

    FileNode fnodes = fn["nodes"];
    CV_Assert( !fnodes.empty() );
    readTree(fnodes);
}

Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}

Ptr<DTrees> DTrees::load(const String& filepath, const String& nodeName)
{
    return Algorithm::load<DTrees>(filepath, nodeName);
}
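
// Minimal usage sketch (hedged; assumes a model previously trained and
// saved with StatModel::save(), e.g. to "tree.yml"):
//
//   Ptr<DTrees> model = DTrees::load("tree.yml");
//   Mat sample(1, model->getVarCount(), CV_32F); // fill with feature values
//   float prediction = model->predict(sample);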

} // namespace ml
} // namespace cv

/* End of file. */