1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // License Agreement
11 // For Open Source Computer Vision Library
12 //
13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
15 // Third party copyrights are property of their respective owners.
16 //
17 // Redistribution and use in source and binary forms, with or without modification,
18 // are permitted provided that the following conditions are met:
19 //
20 // * Redistribution's of source code must retain the above copyright notice,
21 // this list of conditions and the following disclaimer.
22 //
23 // * Redistribution's in binary form must reproduce the above copyright notice,
24 // this list of conditions and the following disclaimer in the documentation
25 // and/or other materials provided with the distribution.
26 //
27 // * The name of the copyright holders may not be used to endorse or promote products
28 // derived from this software without specific prior written permission.
29 //
30 // This software is provided by the copyright holders and contributors "as is" and
31 // any express or implied warranties, including, but not limited to, the implied
32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
33 // In no event shall the Intel Corporation or contributors be liable for any direct,
34 // indirect, incidental, special, exemplary, or consequential damages
35 // (including, but not limited to, procurement of substitute goods or services;
36 // loss of use, data, or profits; or business interruption) however caused
37 // and on any theory of liability, whether in contract, strict liability,
38 // or tort (including negligence or otherwise) arising in any way out of
39 // the use of this software, even if advised of the possibility of such damage.
40 //
41 //M*/
42
43 #include "precomp.hpp"
44 #include <ctype.h>
45
46 #include <opencv2/core/utils/logger.hpp>
47
48 namespace cv {
49 namespace ml {
50
51 using std::vector;
52
TreeParams()53 TreeParams::TreeParams()
54 {
55 maxDepth = INT_MAX;
56 minSampleCount = 10;
57 regressionAccuracy = 0.01f;
58 useSurrogates = false;
59 maxCategories = 10;
60 CVFolds = 10;
61 use1SERule = true;
62 truncatePrunedTree = true;
63 priors = Mat();
64 }
65
TreeParams(int _maxDepth,int _minSampleCount,double _regressionAccuracy,bool _useSurrogates,int _maxCategories,int _CVFolds,bool _use1SERule,bool _truncatePrunedTree,const Mat & _priors)66 TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
67 double _regressionAccuracy, bool _useSurrogates,
68 int _maxCategories, int _CVFolds,
69 bool _use1SERule, bool _truncatePrunedTree,
70 const Mat& _priors)
71 {
72 maxDepth = _maxDepth;
73 minSampleCount = _minSampleCount;
74 regressionAccuracy = (float)_regressionAccuracy;
75 useSurrogates = _useSurrogates;
76 maxCategories = _maxCategories;
77 CVFolds = _CVFolds;
78 use1SERule = _use1SERule;
79 truncatePrunedTree = _truncatePrunedTree;
80 priors = _priors;
81 }
82
Node()83 DTrees::Node::Node()
84 {
85 classIdx = 0;
86 value = 0;
87 parent = left = right = split = defaultDir = -1;
88 }
89
Split()90 DTrees::Split::Split()
91 {
92 varIdx = 0;
93 inversed = false;
94 quality = 0.f;
95 next = -1;
96 c = 0.f;
97 subsetOfs = 0;
98 }
99
100
WorkData(const Ptr<TrainData> & _data)101 DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
102 {
103 CV_Assert(!_data.empty());
104 data = _data;
105 vector<int> subsampleIdx;
106 Mat sidx0 = _data->getTrainSampleIdx();
107 if( !sidx0.empty() )
108 {
109 sidx0.copyTo(sidx);
110 std::sort(sidx.begin(), sidx.end());
111 }
112 else
113 {
114 int n = _data->getNSamples();
115 setRangeVector(sidx, n);
116 }
117
118 maxSubsetSize = 0;
119 }
120
DTreesImpl()121 DTreesImpl::DTreesImpl() : _isClassifier(false) {}
~DTreesImpl()122 DTreesImpl::~DTreesImpl() {}
clear()123 void DTreesImpl::clear()
124 {
125 varIdx.clear();
126 compVarIdx.clear();
127 varType.clear();
128 catOfs.clear();
129 catMap.clear();
130 roots.clear();
131 nodes.clear();
132 splits.clear();
133 subsets.clear();
134 classLabels.clear();
135
136 w.release();
137 _isClassifier = false;
138 }
139
startTraining(const Ptr<TrainData> & data,int)140 void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
141 {
142 CV_Assert(!data.empty());
143 clear();
144 w = makePtr<WorkData>(data);
145
146 Mat vtype = data->getVarType();
147 vtype.copyTo(varType);
148
149 data->getCatOfs().copyTo(catOfs);
150 data->getCatMap().copyTo(catMap);
151 data->getDefaultSubstValues().copyTo(missingSubst);
152
153 int nallvars = data->getNAllVars();
154
155 Mat vidx0 = data->getVarIdx();
156 if( !vidx0.empty() )
157 vidx0.copyTo(varIdx);
158 else
159 setRangeVector(varIdx, nallvars);
160
161 initCompVarIdx();
162
163 w->maxSubsetSize = 0;
164
165 int i, nvars = (int)varIdx.size();
166 for( i = 0; i < nvars; i++ )
167 w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));
168
169 w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);
170
171 data->getSampleWeights().copyTo(w->sample_weights);
172
173 _isClassifier = data->getResponseType() == VAR_CATEGORICAL;
174
175 if( _isClassifier )
176 {
177 data->getNormCatResponses().copyTo(w->cat_responses);
178 data->getClassLabels().copyTo(classLabels);
179 int nclasses = (int)classLabels.size();
180
181 Mat class_weights = params.priors;
182 if( !class_weights.empty() )
183 {
184 if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
185 {
186 Mat temp;
187 class_weights.convertTo(temp, CV_64F);
188 class_weights = temp;
189 }
190 CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );
191
192 int nsamples = (int)w->cat_responses.size();
193 const double* cw = class_weights.ptr<double>();
194 CV_Assert( (int)w->sample_weights.size() == nsamples );
195
196 for( i = 0; i < nsamples; i++ )
197 {
198 int ci = w->cat_responses[i];
199 CV_Assert( 0 <= ci && ci < nclasses );
200 w->sample_weights[i] *= cw[ci];
201 }
202 }
203 }
204 else
205 data->getResponses().copyTo(w->ord_responses);
206 }
207
208
initCompVarIdx()209 void DTreesImpl::initCompVarIdx()
210 {
211 int nallvars = (int)varType.size();
212 compVarIdx.assign(nallvars, -1);
213 int i, nvars = (int)varIdx.size(), prevIdx = -1;
214 for( i = 0; i < nvars; i++ )
215 {
216 int vi = varIdx[i];
217 CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
218 prevIdx = vi;
219 compVarIdx[vi] = i;
220 }
221 }
222
endTraining()223 void DTreesImpl::endTraining()
224 {
225 w.release();
226 }
227
train(const Ptr<TrainData> & trainData,int flags)228 bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
229 {
230 CV_Assert(!trainData.empty());
231 startTraining(trainData, flags);
232 bool ok = addTree( w->sidx ) >= 0;
233 w.release();
234 endTraining();
235 return ok;
236 }
237
getActiveVars()238 const vector<int>& DTreesImpl::getActiveVars()
239 {
240 return varIdx;
241 }
242
addTree(const vector<int> & sidx)243 int DTreesImpl::addTree(const vector<int>& sidx )
244 {
245 size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();
246
247 w->wnodes.reserve(n);
248 w->wsplits.reserve(n);
249 w->wsubsets.reserve(n*w->maxSubsetSize);
250 w->wnodes.clear();
251 w->wsplits.clear();
252 w->wsubsets.clear();
253
254 int cv_n = params.getCVFolds();
255
256 if( cv_n > 0 )
257 {
258 w->cv_Tn.resize(n*cv_n);
259 w->cv_node_error.resize(n*cv_n);
260 w->cv_node_risk.resize(n*cv_n);
261 }
262
263 // build the tree recursively
264 int w_root = addNodeAndTrySplit(-1, sidx);
265 int maxdepth = INT_MAX;//pruneCV(root);
266
267 int w_nidx = w_root, pidx = -1, depth = 0;
268 int root = (int)nodes.size();
269
270 for(;;)
271 {
272 const WNode& wnode = w->wnodes[w_nidx];
273 Node node;
274 node.parent = pidx;
275 node.classIdx = wnode.class_idx;
276 node.value = wnode.value;
277 node.defaultDir = wnode.defaultDir;
278
279 int wsplit_idx = wnode.split;
280 if( wsplit_idx >= 0 )
281 {
282 const WSplit& wsplit = w->wsplits[wsplit_idx];
283 Split split;
284 split.c = wsplit.c;
285 split.quality = wsplit.quality;
286 split.inversed = wsplit.inversed;
287 split.varIdx = wsplit.varIdx;
288 split.subsetOfs = -1;
289 if( wsplit.subsetOfs >= 0 )
290 {
291 int ssize = getSubsetSize(split.varIdx);
292 split.subsetOfs = (int)subsets.size();
293 subsets.resize(split.subsetOfs + ssize);
294 // This check verifies that subsets index is in the correct range
295 // as in case ssize == 0 no real resize performed.
296 // Thus memory kept safe.
297 // Also this skips useless memcpy call when size parameter is zero
298 if(ssize > 0)
299 {
300 memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
301 }
302 }
303 node.split = (int)splits.size();
304 splits.push_back(split);
305 }
306 int nidx = (int)nodes.size();
307 nodes.push_back(node);
308 if( pidx >= 0 )
309 {
310 int w_pidx = w->wnodes[w_nidx].parent;
311 if( w->wnodes[w_pidx].left == w_nidx )
312 {
313 nodes[pidx].left = nidx;
314 }
315 else
316 {
317 CV_Assert(w->wnodes[w_pidx].right == w_nidx);
318 nodes[pidx].right = nidx;
319 }
320 }
321
322 if( wnode.left >= 0 && depth+1 < maxdepth )
323 {
324 w_nidx = wnode.left;
325 pidx = nidx;
326 depth++;
327 }
328 else
329 {
330 int w_pidx = wnode.parent;
331 while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
332 {
333 w_nidx = w_pidx;
334 w_pidx = w->wnodes[w_pidx].parent;
335 nidx = pidx;
336 pidx = nodes[pidx].parent;
337 depth--;
338 }
339
340 if( w_pidx < 0 )
341 break;
342
343 w_nidx = w->wnodes[w_pidx].right;
344 CV_Assert( w_nidx >= 0 );
345 }
346 }
347 roots.push_back(root);
348 return root;
349 }
350
setDParams(const TreeParams & _params)351 void DTreesImpl::setDParams(const TreeParams& _params)
352 {
353 params = _params;
354 }
355
addNodeAndTrySplit(int parent,const vector<int> & sidx)356 int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
357 {
358 w->wnodes.push_back(WNode());
359 int nidx = (int)(w->wnodes.size() - 1);
360 WNode& node = w->wnodes.back();
361
362 node.parent = parent;
363 node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
364 int nfolds = params.getCVFolds();
365
366 if( nfolds > 0 )
367 {
368 w->cv_Tn.resize((nidx+1)*nfolds);
369 w->cv_node_error.resize((nidx+1)*nfolds);
370 w->cv_node_risk.resize((nidx+1)*nfolds);
371 }
372
373 int i, n = node.sample_count = (int)sidx.size();
374 bool can_split = true;
375 vector<int> sleft, sright;
376
377 calcValue( nidx, sidx );
378
379 if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
380 can_split = false;
381 else if( _isClassifier )
382 {
383 const int* responses = &w->cat_responses[0];
384 const int* s = &sidx[0];
385 int first = responses[s[0]];
386 for( i = 1; i < n; i++ )
387 if( responses[s[i]] != first )
388 break;
389 if( i == n )
390 can_split = false;
391 }
392 else
393 {
394 if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
395 can_split = false;
396 }
397
398 if( can_split )
399 node.split = findBestSplit( sidx );
400
401 //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);
402
403 if( node.split >= 0 )
404 {
405 node.defaultDir = calcDir( node.split, sidx, sleft, sright );
406 if( params.useSurrogates )
407 CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");
408
409 int left = addNodeAndTrySplit( nidx, sleft );
410 int right = addNodeAndTrySplit( nidx, sright );
411 w->wnodes[nidx].left = left;
412 w->wnodes[nidx].right = right;
413 CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
414 }
415
416 return nidx;
417 }
418
findBestSplit(const vector<int> & _sidx)419 int DTreesImpl::findBestSplit( const vector<int>& _sidx )
420 {
421 const vector<int>& activeVars = getActiveVars();
422 int splitidx = -1;
423 int vi_, nv = (int)activeVars.size();
424 AutoBuffer<int> buf(w->maxSubsetSize*2);
425 int *subset = buf.data(), *best_subset = subset + w->maxSubsetSize;
426 WSplit split, best_split;
427 best_split.quality = 0.;
428
429 for( vi_ = 0; vi_ < nv; vi_++ )
430 {
431 int vi = activeVars[vi_];
432 if( varType[vi] == VAR_CATEGORICAL )
433 {
434 if( _isClassifier )
435 split = findSplitCatClass(vi, _sidx, 0, subset);
436 else
437 split = findSplitCatReg(vi, _sidx, 0, subset);
438 }
439 else
440 {
441 if( _isClassifier )
442 split = findSplitOrdClass(vi, _sidx, 0);
443 else
444 split = findSplitOrdReg(vi, _sidx, 0);
445 }
446 if( split.quality > best_split.quality )
447 {
448 best_split = split;
449 std::swap(subset, best_subset);
450 }
451 }
452
453 if( best_split.quality > 0 )
454 {
455 int best_vi = best_split.varIdx;
456 CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
457 int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
458 w->wsubsets.resize(prevsz + ssize);
459 for( i = 0; i < ssize; i++ )
460 w->wsubsets[prevsz + i] = best_subset[i];
461 best_split.subsetOfs = prevsz;
462 w->wsplits.push_back(best_split);
463 splitidx = (int)(w->wsplits.size()-1);
464 }
465
466 return splitidx;
467 }
468
// Computes the value and the risk of work node `nidx` from the samples listed
// in `_sidx`, and - when cross-validation folds are enabled - the per-fold
// risk/error entries of that node in w->cv_node_risk / w->cv_node_error
// (w->cv_Tn entries are set to INT_MAX).
//
// Classification: writes node->class_idx, node->value (the label of the class
// with the largest total weight) and node->node_risk.
// Regression: writes node->value (weighted mean response) and node->node_risk;
// asserts that the total sample weight is non-zero.
void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    // scratch space shared by both branches:
    // (cv_n+1) rows of m class counters for classification,
    // or 3 arrays of cv_n doubles for regression.
    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        // grow the per-fold arrays by one more block of cv_n entries
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        double* cls_count = buf.data();
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            // no folds: accumulate the class weights directly
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            // accumulate the class weights per fold first ...
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            // ... then sum the folds to get the totals
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        // pick the class with the largest weight (the first one wins on ties)
        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            // val_k: class weight inside fold j; val: class weight outside fold j
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is the weighted mean response: sum_i(w_i*Y_i)/sum_i(w_i),
        //    where Y_i is i-th response and w_i its sample weight.
        //  * node risk is computed below as
        //    (sum_i(w_i*Y_i^2) - <node_value>*sum_i(w_i*Y_i)) / sum_i(w_i),
        //    i.e. the weighted sum of squared errors divided by the total weight.
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            // per-fold weighted sums, sums of squares and total weights
            double *cv_sum = buf.data(), *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            for( j = 0; j < cv_n; j++ )
            {
                // NOTE(review): here s/s2 are the totals *excluding* fold j and
                // si/s2i the totals *of* fold j, whereas c is the weight *of*
                // fold j and ci the weight excluding it - confirm that this
                // pairing is the intended one.
                double s = sum - cv_sum[j], si = sum - s;
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }
        CV_Assert(fabs(sumw) > 0);
        node->node_risk = sum2 - (sum/sumw)*sum;
        node->node_risk /= sumw;
        node->value = sum/sumw;
    }
}
644
// Searches for the best threshold split on ordered variable `vi` for a
// classification node containing the samples `_sidx`.
//
// The samples are ordered by the variable value and every boundary between
// two neighbouring samples is evaluated with the criterion
//     (lsum2*R + rsum2*L)/(L*R)  ==  lsum2/L + rsum2/R,
// where L/R are the total weights on each side and lsum2/rsum2 the sums of
// squared per-class weights on each side. Only candidates strictly better
// than `initQuality` are accepted.
//
// Returns a WSplit with varIdx, c (the midpoint threshold) and quality filled
// in, or a default-constructed WSplit when no acceptable split was found.
DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    // one buffer, laid out as: lcw[m] | rcw[m] | values[n] | sorted_idx[n]
    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)buf.data();
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    // initially all samples are on the right side
    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    // order sample positions by the variable value
    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        double lv = lcw[idx], rv = rcw[idx];
        // move the sample from right to left and update the sums of squares
        // incrementally: (lv + w)^2 = lv^2 + 2*lv*w + w^2,
        //                (rv - w)^2 = rv^2 - (2*rv*w - w^2)
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        // a threshold is usable only if the midpoint lies strictly between
        // the two neighbouring values (in float precision)
        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
718
// Simple k-means, slightly modified to take into account the "weight"
// (the sum of the components) of each vector: vectors and cluster sums are
// compared after each is scaled by the reciprocal of its component sum.
//
//  vectors - n input vectors of m doubles each, stored row by row
//  csums   - output, k rows of m doubles: the per-cluster component sums
//  labels  - output, n entries: the cluster index assigned to each vector
//
// The random generator is seeded with a fixed constant, and at most
// max_iters (100) reassignment passes are performed.
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf.data(), *c_weights = buf.data() + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        // the first k vectors get distinct labels 0..k-1
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    // shuffle the initial labels with n random swaps
    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                // squared distance between the two scaled vectors
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}
812
findSplitCatClass(int vi,const vector<int> & _sidx,double initQuality,int * subset)813 DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
814 double initQuality, int* subset )
815 {
816 int _mi = getCatCount(vi), mi = _mi;
817 int n = (int)_sidx.size();
818 int m = (int)classLabels.size();
819
820 int base_size = m*(3 + mi) + mi + 1;
821 if( m > 2 && mi > params.getMaxCategories() )
822 base_size += m*std::min(params.getMaxCategories(), n) + mi;
823 else
824 base_size += mi;
825 AutoBuffer<double> buf(base_size + n);
826
827 double* lc = buf.data();
828 double* rc = lc + m;
829 double* _cjk = rc + m*2, *cjk = _cjk;
830 double* c_weights = cjk + m*mi;
831
832 int* labels = (int*)(buf.data() + base_size);
833 w->data->getNormCatValues(vi, _sidx, labels);
834 const int* responses = &w->cat_responses[0];
835 const double* weights = &w->sample_weights[0];
836
837 int* cluster_labels = 0;
838 double** dbl_ptr = 0;
839 int i, j, k, si, idx;
840 double L = 0, R = 0;
841 double best_val = initQuality;
842 int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
843
844 // init array of counters:
845 // c_{jk} - number of samples that have vi-th input variable = j and response = k.
846 for( j = -1; j < mi; j++ )
847 for( k = 0; k < m; k++ )
848 cjk[j*m + k] = 0;
849
850 for( i = 0; i < n; i++ )
851 {
852 si = _sidx[i];
853 j = labels[i];
854 k = responses[si];
855 cjk[j*m + k] += weights[si];
856 }
857
858 if( m > 2 )
859 {
860 if( mi > params.getMaxCategories() )
861 {
862 mi = std::min(params.getMaxCategories(), n);
863 cjk = c_weights + _mi;
864 cluster_labels = (int*)(cjk + m*mi);
865 clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
866 }
867 subset_i = 1;
868 subset_n = 1 << mi;
869 }
870 else
871 {
872 assert( m == 2 );
873 dbl_ptr = (double**)(c_weights + _mi);
874 for( j = 0; j < mi; j++ )
875 dbl_ptr[j] = cjk + j*2 + 1;
876 std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
877 subset_i = 0;
878 subset_n = mi;
879 }
880
881 for( k = 0; k < m; k++ )
882 {
883 double sum = 0;
884 for( j = 0; j < mi; j++ )
885 sum += cjk[j*m + k];
886 CV_Assert(sum > 0);
887 rc[k] = sum;
888 lc[k] = 0;
889 }
890
891 for( j = 0; j < mi; j++ )
892 {
893 double sum = 0;
894 for( k = 0; k < m; k++ )
895 sum += cjk[j*m + k];
896 c_weights[j] = sum;
897 R += c_weights[j];
898 }
899
900 for( ; subset_i < subset_n; subset_i++ )
901 {
902 double lsum2 = 0, rsum2 = 0;
903
904 if( m == 2 )
905 idx = (int)(dbl_ptr[subset_i] - cjk)/2;
906 else
907 {
908 int graycode = (subset_i>>1)^subset_i;
909 int diff = graycode ^ prevcode;
910
911 // determine index of the changed bit.
912 Cv32suf u;
913 idx = diff >= (1 << 16) ? 16 : 0;
914 u.f = (float)(((diff >> 16) | diff) & 65535);
915 idx += (u.i >> 23) - 127;
916 subtract = graycode < prevcode;
917 prevcode = graycode;
918 }
919
920 double* crow = cjk + idx*m;
921 double weight = c_weights[idx];
922 if( weight < FLT_EPSILON )
923 continue;
924
925 if( !subtract )
926 {
927 for( k = 0; k < m; k++ )
928 {
929 double t = crow[k];
930 double lval = lc[k] + t;
931 double rval = rc[k] - t;
932 lsum2 += lval*lval;
933 rsum2 += rval*rval;
934 lc[k] = lval; rc[k] = rval;
935 }
936 L += weight;
937 R -= weight;
938 }
939 else
940 {
941 for( k = 0; k < m; k++ )
942 {
943 double t = crow[k];
944 double lval = lc[k] - t;
945 double rval = rc[k] + t;
946 lsum2 += lval*lval;
947 rsum2 += rval*rval;
948 lc[k] = lval; rc[k] = rval;
949 }
950 L -= weight;
951 R += weight;
952 }
953
954 if( L > FLT_EPSILON && R > FLT_EPSILON )
955 {
956 double val = (lsum2*R + rsum2*L)/(L*R);
957 if( best_val < val )
958 {
959 best_val = val;
960 best_subset = subset_i;
961 }
962 }
963 }
964
965 WSplit split;
966 if( best_subset >= 0 )
967 {
968 split.varIdx = vi;
969 split.quality = (float)best_val;
970 memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
971 if( m == 2 )
972 {
973 for( i = 0; i <= best_subset; i++ )
974 {
975 idx = (int)(dbl_ptr[i] - cjk) >> 1;
976 subset[idx >> 5] |= 1 << (idx & 31);
977 }
978 }
979 else
980 {
981 for( i = 0; i < _mi; i++ )
982 {
983 idx = cluster_labels ? cluster_labels[i] : i;
984 if( best_subset & (1 << idx) )
985 subset[i >> 5] |= 1 << (i & 31);
986 }
987 }
988 }
989 return split;
990 }
991
// Searches for the best threshold split on ordered variable `vi` for a
// regression node containing the samples `_sidx`.
//
// The samples are ordered by the variable value and every boundary between
// two neighbouring samples is evaluated with the criterion
//     (lsum*lsum*R + rsum*rsum*L)/(L*R)  ==  lsum^2/L + rsum^2/R,
// where L/R are the total weights and lsum/rsum the weighted response sums
// on each side. Only candidates strictly better than `initQuality` are
// accepted.
//
// Returns a WSplit with varIdx, c (the midpoint threshold) and quality filled
// in, or a default-constructed WSplit when no acceptable split was found.
DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
{
    const double* weights = &w->sample_weights[0];
    int n = (int)_sidx.size();

    // one buffer, laid out as: values[n] | sorted_idx[n]
    AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));

    float* values = (float*)buf.data();
    int* sorted_idx = (int*)(values + n);
    w->data->getValues(vi, _sidx, values);
    const double* responses = &w->ord_responses[0];

    int i, si, best_i = -1;
    double L = 0, R = 0;
    double best_val = initQuality, lsum = 0, rsum = 0;

    // initially all samples are on the right side
    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        si = _sidx[i];
        R += weights[si];
        rsum += weights[si]*responses[si];
    }

    // order sample positions by the variable value
    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    // find the optimal split
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        si = _sidx[curr];
        double wval = weights[si];
        double t = responses[si]*wval;
        // move the sample from the right side to the left side
        L += wval; R -= wval;
        lsum += t; rsum -= t;

        // a threshold is usable only if the midpoint lies strictly between
        // the two neighbouring values (in float precision)
        float value_between = (values[next] + values[curr]) * 0.5f;
        if( value_between > values[curr] && value_between < values[next] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
1051
// Searches for the best subset split on categorical variable `vi` for a
// regression node containing the samples `_sidx`.
//
// The categories are sorted by their average response and only the prefixes
// of that order are evaluated, each with the criterion
//     (lsum*lsum*R + rsum*rsum*L)/(L*R)  ==  lsum^2/L + rsum^2/R,
// where L/R are the total weights and lsum/rsum the weighted response sums
// on each side. Only candidates strictly better than `initQuality` are
// accepted.
//
// On success the returned WSplit has varIdx and quality set, and `subset`
// (getSubsetSize(vi) ints) has one bit set for every category of the selected
// prefix. Otherwise a default-constructed WSplit is returned and `subset` is
// left untouched.
DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    // scratch buffer layout (in doubles):
    //   sum[-1..mi-1] | counts[-1..mi-1] | sum_ptr (mi pointers) | cat_labels (n ints)
    // sum and counts are offset by one so that index -1 is addressable.
    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = buf.data() + 1;
    double* counts = sum + mi + 1;
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    // (initially all categories are on the right side)
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;
    }

    // order the categories by their average response
    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert back to unnormalized sums
    // (there should be a very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    // move the categories to the left side one by one, in sorted order;
    // the last category is never moved, so the right side is not emptied here
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        double ni = counts[idx];

        // categories with (almost) no weight do not change the split
        if( ni > FLT_EPSILON )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        // mark the first best_subset+1 categories of the sorted order
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
1137
calcDir(int splitidx,const vector<int> & _sidx,vector<int> & _sleft,vector<int> & _sright)1138 int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
1139 vector<int>& _sleft, vector<int>& _sright )
1140 {
1141 WSplit split = w->wsplits[splitidx];
1142 int i, si, n = (int)_sidx.size(), vi = split.varIdx;
1143 _sleft.reserve(n);
1144 _sright.reserve(n);
1145 _sleft.clear();
1146 _sright.clear();
1147
1148 AutoBuffer<float> buf(n);
1149 int mi = getCatCount(vi);
1150 double wleft = 0, wright = 0;
1151 const double* weights = &w->sample_weights[0];
1152
1153 if( mi <= 0 ) // split on an ordered variable
1154 {
1155 float c = split.c;
1156 float* values = buf.data();
1157 w->data->getValues(vi, _sidx, values);
1158
1159 for( i = 0; i < n; i++ )
1160 {
1161 si = _sidx[i];
1162 if( values[i] <= c )
1163 {
1164 _sleft.push_back(si);
1165 wleft += weights[si];
1166 }
1167 else
1168 {
1169 _sright.push_back(si);
1170 wright += weights[si];
1171 }
1172 }
1173 }
1174 else
1175 {
1176 const int* subset = &w->wsubsets[split.subsetOfs];
1177 int* cat_labels = (int*)buf.data();
1178 w->data->getNormCatValues(vi, _sidx, cat_labels);
1179
1180 for( i = 0; i < n; i++ )
1181 {
1182 si = _sidx[i];
1183 unsigned u = cat_labels[i];
1184 if( CV_DTREE_CAT_DIR(u, subset) < 0 )
1185 {
1186 _sleft.push_back(si);
1187 wleft += weights[si];
1188 }
1189 else
1190 {
1191 _sright.push_back(si);
1192 wright += weights[si];
1193 }
1194 }
1195 }
1196 CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
1197 return wleft > wright ? -1 : 1;
1198 }
1199
pruneCV(int root)1200 int DTreesImpl::pruneCV( int root )
1201 {
1202 vector<double> ab;
1203
1204 // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
1205 // 2. choose the best tree index (if need, apply 1SE rule).
1206 // 3. store the best index and cut the branches.
1207
1208 int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
1209 // currently, 1SE for regression is not implemented
1210 bool use_1se = params.use1SERule != 0 && _isClassifier;
1211 double min_err = 0, min_err_se = 0;
1212 int min_idx = -1;
1213
1214 // build the main tree sequence, calculate alpha's
1215 for(;;tree_count++)
1216 {
1217 double min_alpha = updateTreeRNC(root, tree_count, -1);
1218 if( cutTree(root, tree_count, -1, min_alpha) )
1219 break;
1220
1221 ab.push_back(min_alpha);
1222 }
1223
1224 if( tree_count > 0 )
1225 {
1226 ab[0] = 0.;
1227
1228 for( ti = 1; ti < tree_count-1; ti++ )
1229 ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
1230 ab[tree_count-1] = DBL_MAX*0.5;
1231
1232 Mat err_jk(cv_n, tree_count, CV_64F);
1233
1234 for( j = 0; j < cv_n; j++ )
1235 {
1236 int tj = 0, tk = 0;
1237 for( ; tj < tree_count; tj++ )
1238 {
1239 double min_alpha = updateTreeRNC(root, tj, j);
1240 if( cutTree(root, tj, j, min_alpha) )
1241 min_alpha = DBL_MAX;
1242
1243 for( ; tk < tree_count; tk++ )
1244 {
1245 if( ab[tk] > min_alpha )
1246 break;
1247 err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
1248 }
1249 }
1250 }
1251
1252 for( ti = 0; ti < tree_count; ti++ )
1253 {
1254 double sum_err = 0;
1255 for( j = 0; j < cv_n; j++ )
1256 sum_err += err_jk.at<double>(j, ti);
1257 if( ti == 0 || sum_err < min_err )
1258 {
1259 min_err = sum_err;
1260 min_idx = ti;
1261 if( use_1se )
1262 min_err_se = sqrt( sum_err*(n - sum_err) );
1263 }
1264 else if( sum_err < min_err + min_err_se )
1265 min_idx = ti;
1266 }
1267 }
1268
1269 return min_idx;
1270 }
1271
updateTreeRNC(int root,double T,int fold)1272 double DTreesImpl::updateTreeRNC( int root, double T, int fold )
1273 {
1274 int nidx = root, pidx = -1, cv_n = params.getCVFolds();
1275 double min_alpha = DBL_MAX;
1276
1277 for(;;)
1278 {
1279 WNode *node = 0, *parent = 0;
1280
1281 for(;;)
1282 {
1283 node = &w->wnodes[nidx];
1284 double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
1285 if( t <= T || node->left < 0 )
1286 {
1287 node->complexity = 1;
1288 node->tree_risk = node->node_risk;
1289 node->tree_error = 0.;
1290 if( fold >= 0 )
1291 {
1292 node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
1293 node->tree_error = w->cv_node_error[nidx*cv_n + fold];
1294 }
1295 break;
1296 }
1297 nidx = node->left;
1298 }
1299
1300 for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
1301 nidx = pidx, pidx = w->wnodes[pidx].parent )
1302 {
1303 node = &w->wnodes[nidx];
1304 parent = &w->wnodes[pidx];
1305 parent->complexity += node->complexity;
1306 parent->tree_risk += node->tree_risk;
1307 parent->tree_error += node->tree_error;
1308
1309 parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
1310 - parent->tree_risk)/(parent->complexity - 1);
1311 min_alpha = std::min( min_alpha, parent->alpha );
1312 }
1313
1314 if( pidx < 0 )
1315 break;
1316
1317 node = &w->wnodes[nidx];
1318 parent = &w->wnodes[pidx];
1319 parent->complexity = node->complexity;
1320 parent->tree_risk = node->tree_risk;
1321 parent->tree_error = node->tree_error;
1322 nidx = parent->right;
1323 }
1324
1325 return min_alpha;
1326 }
1327
cutTree(int root,double T,int fold,double min_alpha)1328 bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
1329 {
1330 int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
1331 WNode* node = &w->wnodes[root];
1332 if( node->left < 0 )
1333 return true;
1334
1335 for(;;)
1336 {
1337 for(;;)
1338 {
1339 node = &w->wnodes[nidx];
1340 double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
1341 if( t <= T || node->left < 0 )
1342 break;
1343 if( node->alpha <= min_alpha + FLT_EPSILON )
1344 {
1345 if( fold >= 0 )
1346 w->cv_Tn[nidx*cv_n + fold] = T;
1347 else
1348 node->Tn = T;
1349 if( nidx == root )
1350 return true;
1351 break;
1352 }
1353 nidx = node->left;
1354 }
1355
1356 for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
1357 nidx = pidx, pidx = w->wnodes[pidx].parent )
1358 ;
1359
1360 if( pidx < 0 )
1361 break;
1362
1363 nidx = w->wnodes[pidx].right;
1364 }
1365
1366 return false;
1367 }
1368
predictTrees(const Range & range,const Mat & sample,int flags) const1369 float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
1370 {
1371 CV_Assert( sample.type() == CV_32F );
1372
1373 int predictType = flags & PREDICT_MASK;
1374 int nvars = (int)varIdx.size();
1375 if( nvars == 0 )
1376 nvars = (int)varType.size();
1377 int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
1378 int catbufsize = ncats > 0 ? nvars : 0;
1379 AutoBuffer<int> buf(nclasses + catbufsize + 1);
1380 int* votes = buf.data();
1381 int* catbuf = votes + nclasses;
1382 const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
1383 const uchar* vtype = &varType[0];
1384 const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
1385 const int* cmap = !catMap.empty() ? &catMap[0] : 0;
1386 const float* psample = sample.ptr<float>();
1387 const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
1388 size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
1389 double sum = 0.;
1390 int lastClassIdx = -1;
1391 const float MISSED_VAL = TrainData::missingValue();
1392
1393 for( i = 0; i < catbufsize; i++ )
1394 catbuf[i] = -1;
1395
1396 if( predictType == PREDICT_AUTO )
1397 {
1398 predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
1399 PREDICT_SUM : PREDICT_MAX_VOTE;
1400 }
1401
1402 if( predictType == PREDICT_MAX_VOTE )
1403 {
1404 for( i = 0; i < nclasses; i++ )
1405 votes[i] = 0;
1406 }
1407
1408 for( int ridx = range.start; ridx < range.end; ridx++ )
1409 {
1410 int nidx = roots[ridx], prev = nidx, c = 0;
1411
1412 for(;;)
1413 {
1414 prev = nidx;
1415 const Node& node = nodes[nidx];
1416 if( node.split < 0 )
1417 break;
1418 const Split& split = splits[node.split];
1419 int vi = split.varIdx;
1420 int ci = cvidx ? cvidx[vi] : vi;
1421 float val = psample[ci*sstep];
1422 if( val == MISSED_VAL )
1423 {
1424 if( !missingSubstPtr )
1425 {
1426 nidx = node.defaultDir < 0 ? node.left : node.right;
1427 continue;
1428 }
1429 val = missingSubstPtr[vi];
1430 }
1431
1432 if( vtype[vi] == VAR_ORDERED )
1433 nidx = val <= split.c ? node.left : node.right;
1434 else
1435 {
1436 if( flags & PREPROCESSED_INPUT )
1437 c = cvRound(val);
1438 else
1439 {
1440 c = catbuf[ci];
1441 if( c < 0 )
1442 {
1443 int a = c = cofs[vi][0];
1444 int b = cofs[vi][1];
1445
1446 int ival = cvRound(val);
1447 if( ival != val )
1448 CV_Error( CV_StsBadArg,
1449 "one of input categorical variable is not an integer" );
1450
1451 CV_Assert(cmap != NULL);
1452 while( a < b )
1453 {
1454 c = (a + b) >> 1;
1455 if( ival < cmap[c] )
1456 b = c;
1457 else if( ival > cmap[c] )
1458 a = c+1;
1459 else
1460 break;
1461 }
1462
1463 CV_Assert( c >= 0 && ival == cmap[c] );
1464
1465 c -= cofs[vi][0];
1466 catbuf[ci] = c;
1467 }
1468 const int* subset = &subsets[split.subsetOfs];
1469 unsigned u = c;
1470 nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
1471 }
1472 }
1473 }
1474
1475 if( predictType == PREDICT_SUM )
1476 sum += nodes[prev].value;
1477 else
1478 {
1479 lastClassIdx = nodes[prev].classIdx;
1480 votes[lastClassIdx]++;
1481 }
1482 }
1483
1484 if( predictType == PREDICT_MAX_VOTE )
1485 {
1486 int best_idx = lastClassIdx;
1487 if( range.end - range.start > 1 )
1488 {
1489 best_idx = 0;
1490 for( i = 1; i < nclasses; i++ )
1491 if( votes[best_idx] < votes[i] )
1492 best_idx = i;
1493 }
1494 sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
1495 }
1496
1497 return (float)sum;
1498 }
1499
1500
predict(InputArray _samples,OutputArray _results,int flags) const1501 float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
1502 {
1503 CV_Assert( !roots.empty() );
1504 Mat samples = _samples.getMat(), results;
1505 int i, nsamples = samples.rows;
1506 int rtype = CV_32F;
1507 bool needresults = _results.needed();
1508 float retval = 0.f;
1509 bool iscls = isClassifier();
1510 float scale = !iscls ? 1.f/(int)roots.size() : 1.f;
1511
1512 if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
1513 rtype = CV_32S;
1514
1515 if( needresults )
1516 {
1517 _results.create(nsamples, 1, rtype);
1518 results = _results.getMat();
1519 }
1520 else
1521 nsamples = std::min(nsamples, 1);
1522
1523 for( i = 0; i < nsamples; i++ )
1524 {
1525 float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
1526 if( needresults )
1527 {
1528 if( rtype == CV_32F )
1529 results.at<float>(i) = val;
1530 else
1531 results.at<int>(i) = cvRound(val);
1532 }
1533 if( i == 0 )
1534 retval = val;
1535 }
1536 return retval;
1537 }
1538
writeTrainingParams(FileStorage & fs) const1539 void DTreesImpl::writeTrainingParams(FileStorage& fs) const
1540 {
1541 fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
1542 fs << "max_categories" << params.getMaxCategories();
1543 fs << "regression_accuracy" << params.getRegressionAccuracy();
1544
1545 fs << "max_depth" << params.getMaxDepth();
1546 fs << "min_sample_count" << params.getMinSampleCount();
1547 fs << "cross_validation_folds" << params.getCVFolds();
1548
1549 if( params.getCVFolds() > 1 )
1550 fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);
1551
1552 if( !params.priors.empty() )
1553 fs << "priors" << params.priors;
1554 }
1555
writeParams(FileStorage & fs) const1556 void DTreesImpl::writeParams(FileStorage& fs) const
1557 {
1558 fs << "is_classifier" << isClassifier();
1559 fs << "var_all" << (int)varType.size();
1560 fs << "var_count" << getVarCount();
1561
1562 int ord_var_count = 0, cat_var_count = 0;
1563 int i, n = (int)varType.size();
1564 for( i = 0; i < n; i++ )
1565 if( varType[i] == VAR_ORDERED )
1566 ord_var_count++;
1567 else
1568 cat_var_count++;
1569 fs << "ord_var_count" << ord_var_count;
1570 fs << "cat_var_count" << cat_var_count;
1571
1572 fs << "training_params" << "{";
1573 writeTrainingParams(fs);
1574
1575 fs << "}";
1576
1577 if( !varIdx.empty() )
1578 {
1579 fs << "global_var_idx" << 1;
1580 fs << "var_idx" << varIdx;
1581 }
1582
1583 fs << "var_type" << varType;
1584
1585 if( !catOfs.empty() )
1586 fs << "cat_ofs" << catOfs;
1587 if( !catMap.empty() )
1588 fs << "cat_map" << catMap;
1589 if( !classLabels.empty() )
1590 fs << "class_labels" << classLabels;
1591 if( !missingSubst.empty() )
1592 fs << "missing_subst" << missingSubst;
1593 }
1594
writeSplit(FileStorage & fs,int splitidx) const1595 void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
1596 {
1597 const Split& split = splits[splitidx];
1598
1599 fs << "{:";
1600
1601 int vi = split.varIdx;
1602 fs << "var" << vi;
1603 fs << "quality" << split.quality;
1604
1605 if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
1606 {
1607 int i, n = getCatCount(vi), to_right = 0;
1608 const int* subset = &subsets[split.subsetOfs];
1609 for( i = 0; i < n; i++ )
1610 to_right += CV_DTREE_CAT_DIR(i, subset) > 0;
1611
1612 // ad-hoc rule when to use inverse categorical split notation
1613 // to achieve more compact and clear representation
1614 int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;
1615
1616 fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";
1617
1618 for( i = 0; i < n; i++ )
1619 {
1620 int dir = CV_DTREE_CAT_DIR(i, subset);
1621 if( dir*default_dir < 0 )
1622 fs << i;
1623 }
1624
1625 fs << "]";
1626 }
1627 else
1628 fs << (!split.inversed ? "le" : "gt") << split.c;
1629
1630 fs << "}";
1631 }
1632
writeNode(FileStorage & fs,int nidx,int depth) const1633 void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
1634 {
1635 const Node& node = nodes[nidx];
1636 fs << "{";
1637 fs << "depth" << depth;
1638 fs << "value" << node.value;
1639
1640 if( _isClassifier )
1641 fs << "norm_class_idx" << node.classIdx;
1642
1643 if( node.split >= 0 )
1644 {
1645 fs << "splits" << "[";
1646
1647 for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
1648 writeSplit( fs, splitidx );
1649
1650 fs << "]";
1651 }
1652
1653 fs << "}";
1654 }
1655
writeTree(FileStorage & fs,int root) const1656 void DTreesImpl::writeTree( FileStorage& fs, int root ) const
1657 {
1658 fs << "nodes" << "[";
1659
1660 int nidx = root, pidx = 0, depth = 0;
1661 const Node *node = 0;
1662
1663 // traverse the tree and save all the nodes in depth-first order
1664 for(;;)
1665 {
1666 for(;;)
1667 {
1668 writeNode( fs, nidx, depth );
1669 node = &nodes[nidx];
1670 if( node->left < 0 )
1671 break;
1672 nidx = node->left;
1673 depth++;
1674 }
1675
1676 for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
1677 nidx = pidx, pidx = nodes[pidx].parent )
1678 depth--;
1679
1680 if( pidx < 0 )
1681 break;
1682
1683 nidx = nodes[pidx].right;
1684 }
1685
1686 fs << "]";
1687 }
1688
write(FileStorage & fs) const1689 void DTreesImpl::write( FileStorage& fs ) const
1690 {
1691 writeFormat(fs);
1692 writeParams(fs);
1693 writeTree(fs, roots[0]);
1694 }
1695
readParams(const FileNode & fn)1696 void DTreesImpl::readParams( const FileNode& fn )
1697 {
1698 _isClassifier = (int)fn["is_classifier"] != 0;
1699 int varAll = (int)fn["var_all"];
1700 int varCount = (int)fn["var_count"];
1701 /*int cat_var_count = (int)fn["cat_var_count"];
1702 int ord_var_count = (int)fn["ord_var_count"];*/
1703
1704 if (varAll <= 0)
1705 CV_Error(Error::StsParseError, "The field \"var_all\" of DTree classifier is missing or non-positive");
1706
1707 FileNode tparams_node = fn["training_params"];
1708
1709 TreeParams params0 = TreeParams();
1710
1711 if( !tparams_node.empty() ) // training parameters are not necessary
1712 {
1713 params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
1714 params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
1715 params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
1716 params0.setMaxDepth((int)tparams_node["max_depth"]);
1717 params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
1718 params0.setCVFolds((int)tparams_node["cross_validation_folds"]);
1719
1720 if( params0.getCVFolds() > 1 )
1721 {
1722 params.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
1723 }
1724
1725 tparams_node["priors"] >> params0.priors;
1726 }
1727
1728 readVectorOrMat(fn["var_idx"], varIdx);
1729 fn["var_type"] >> varType;
1730
1731 bool isLegacy = false;
1732 if (fn["format"].empty()) // Export bug until OpenCV 3.2: https://github.com/opencv/opencv/pull/6314
1733 {
1734 if (!fn["cat_ofs"].empty())
1735 isLegacy = false; // 2.4 doesn't store "cat_ofs"
1736 else if (!fn["missing_subst"].empty())
1737 isLegacy = false; // 2.4 doesn't store "missing_subst"
1738 else if (!fn["class_labels"].empty())
1739 isLegacy = false; // 2.4 doesn't store "class_labels"
1740 else if ((int)varType.size() != varAll)
1741 isLegacy = true; // 3.0+: https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1576
1742 else if (/*(int)varType.size() == varAll &&*/ varCount == varAll)
1743 isLegacy = true;
1744 else
1745 {
1746 // 3.0+:
1747 // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/tree.cpp#L1552-L1553
1748 // - https://github.com/opencv/opencv/blame/3.0.0/modules/ml/src/precomp.hpp#L296
1749 isLegacy = !(varCount + 1 == varAll);
1750 }
1751 CV_LOG_INFO(NULL, "ML/DTrees: possible missing 'format' field due to bug of OpenCV export implementation. "
1752 "Details: https://github.com/opencv/opencv/issues/5412. Consider re-exporting of saved ML model. "
1753 "isLegacy = " << isLegacy);
1754 }
1755 else
1756 {
1757 int format = 0;
1758 fn["format"] >> format;
1759 CV_CheckGT(format, 0, "");
1760 isLegacy = format < 3;
1761 }
1762
1763 if (isLegacy && (int)varType.size() <= varAll)
1764 {
1765 std::vector<uchar> extendedTypes(varAll + 1, 0);
1766
1767 int i = 0, n;
1768 if (!varIdx.empty())
1769 {
1770 n = (int)varIdx.size();
1771 for (; i < n; ++i)
1772 {
1773 int var = varIdx[i];
1774 extendedTypes[var] = varType[i];
1775 }
1776 }
1777 else
1778 {
1779 n = (int)varType.size();
1780 for (; i < n; ++i)
1781 {
1782 extendedTypes[i] = varType[i];
1783 }
1784 }
1785 extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
1786 extendedTypes.swap(varType);
1787 }
1788
1789 readVectorOrMat(fn["cat_map"], catMap);
1790
1791 if (isLegacy)
1792 {
1793 // generating "catOfs" from "cat_count"
1794 catOfs.clear();
1795 classLabels.clear();
1796 std::vector<int> counts;
1797 readVectorOrMat(fn["cat_count"], counts);
1798 unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
1799 for (; i < size; ++i)
1800 {
1801 Vec2i newOffsets(0, 0);
1802 if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
1803 {
1804 newOffsets[0] = curShift;
1805 curShift += counts[j];
1806 newOffsets[1] = curShift;
1807 ++j;
1808 }
1809 catOfs.push_back(newOffsets);
1810 }
1811 // other elements in "catMap" are "classLabels"
1812 if (curShift < catMap.size())
1813 {
1814 classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
1815 catMap.erase(catMap.begin() + curShift, catMap.end());
1816 }
1817 }
1818 else
1819 {
1820 fn["cat_ofs"] >> catOfs;
1821 fn["missing_subst"] >> missingSubst;
1822 fn["class_labels"] >> classLabels;
1823 }
1824
1825 // init var mapping for node reading (var indexes or varIdx indexes)
1826 bool globalVarIdx = false;
1827 fn["global_var_idx"] >> globalVarIdx;
1828 if (globalVarIdx || varIdx.empty())
1829 setRangeVector(varMapping, (int)varType.size());
1830 else
1831 varMapping = varIdx;
1832
1833 initCompVarIdx();
1834 setDParams(params0);
1835 }
1836
readSplit(const FileNode & fn)1837 int DTreesImpl::readSplit( const FileNode& fn )
1838 {
1839 Split split;
1840
1841 int vi = (int)fn["var"];
1842 CV_Assert( 0 <= vi && vi <= (int)varType.size() );
1843 vi = varMapping[vi]; // convert to varIdx if needed
1844 split.varIdx = vi;
1845
1846 if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
1847 {
1848 int i, val, ssize = getSubsetSize(vi);
1849 split.subsetOfs = (int)subsets.size();
1850 for( i = 0; i < ssize; i++ )
1851 subsets.push_back(0);
1852 int* subset = &subsets[split.subsetOfs];
1853 FileNode fns = fn["in"];
1854 if( fns.empty() )
1855 {
1856 fns = fn["not_in"];
1857 split.inversed = true;
1858 }
1859
1860 if( fns.isInt() )
1861 {
1862 val = (int)fns;
1863 subset[val >> 5] |= 1 << (val & 31);
1864 }
1865 else
1866 {
1867 FileNodeIterator it = fns.begin();
1868 int n = (int)fns.size();
1869 for( i = 0; i < n; i++, ++it )
1870 {
1871 val = (int)*it;
1872 subset[val >> 5] |= 1 << (val & 31);
1873 }
1874 }
1875
1876 // for categorical splits we do not use inversed splits,
1877 // instead we inverse the variable set in the split
1878 if( split.inversed )
1879 {
1880 for( i = 0; i < ssize; i++ )
1881 subset[i] ^= -1;
1882 split.inversed = false;
1883 }
1884 }
1885 else
1886 {
1887 FileNode cmpNode = fn["le"];
1888 if( cmpNode.empty() )
1889 {
1890 cmpNode = fn["gt"];
1891 split.inversed = true;
1892 }
1893 split.c = (float)cmpNode;
1894 }
1895
1896 split.quality = (float)fn["quality"];
1897 splits.push_back(split);
1898
1899 return (int)(splits.size() - 1);
1900 }
1901
readNode(const FileNode & fn)1902 int DTreesImpl::readNode( const FileNode& fn )
1903 {
1904 Node node;
1905 node.value = (double)fn["value"];
1906
1907 if( _isClassifier )
1908 node.classIdx = (int)fn["norm_class_idx"];
1909
1910 FileNode sfn = fn["splits"];
1911 if( !sfn.empty() )
1912 {
1913 int i, n = (int)sfn.size(), prevsplit = -1;
1914 FileNodeIterator it = sfn.begin();
1915
1916 for( i = 0; i < n; i++, ++it )
1917 {
1918 int splitidx = readSplit(*it);
1919 if( splitidx < 0 )
1920 break;
1921 if( prevsplit < 0 )
1922 node.split = splitidx;
1923 else
1924 splits[prevsplit].next = splitidx;
1925 prevsplit = splitidx;
1926 }
1927 }
1928 nodes.push_back(node);
1929 return (int)(nodes.size() - 1);
1930 }
1931
readTree(const FileNode & fn)1932 int DTreesImpl::readTree( const FileNode& fn )
1933 {
1934 int i, n = (int)fn.size(), root = -1, pidx = -1;
1935 FileNodeIterator it = fn.begin();
1936
1937 for( i = 0; i < n; i++, ++it )
1938 {
1939 int nidx = readNode(*it);
1940 if( nidx < 0 )
1941 break;
1942 Node& node = nodes[nidx];
1943 node.parent = pidx;
1944 if( pidx < 0 )
1945 root = nidx;
1946 else
1947 {
1948 Node& parent = nodes[pidx];
1949 if( parent.left < 0 )
1950 parent.left = nidx;
1951 else
1952 parent.right = nidx;
1953 }
1954 if( node.split >= 0 )
1955 pidx = nidx;
1956 else
1957 {
1958 while( pidx >= 0 && nodes[pidx].right >= 0 )
1959 pidx = nodes[pidx].parent;
1960 }
1961 }
1962 roots.push_back(root);
1963 return root;
1964 }
1965
read(const FileNode & fn)1966 void DTreesImpl::read( const FileNode& fn )
1967 {
1968 clear();
1969 readParams(fn);
1970
1971 FileNode fnodes = fn["nodes"];
1972 CV_Assert( !fnodes.empty() );
1973 readTree(fnodes);
1974 }
1975
create()1976 Ptr<DTrees> DTrees::create()
1977 {
1978 return makePtr<DTreesImpl>();
1979 }
1980
load(const String & filepath,const String & nodeName)1981 Ptr<DTrees> DTrees::load(const String& filepath, const String& nodeName)
1982 {
1983 return Algorithm::load<DTrees>(filepath, nodeName);
1984 }
1985
1986
1987 }
1988 }
1989
1990 /* End of file. */
1991