1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 // By downloading, copying, installing or using the software you agree to this license.
6 // If you do not agree to this license, do not download, install,
7 // copy or use the software.
8 //
9 //
10 // Intel License Agreement
11 // For Open Source Computer Vision Library
12 //
13 // Copyright( C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 // * Redistribution's of source code must retain the above copyright notice,
20 // this list of conditions and the following disclaimer.
21 //
22 // * Redistribution's in binary form must reproduce the above copyright notice,
23 // this list of conditions and the following disclaimer in the documentation
24 // and/or other materials provided with the distribution.
25 //
26 // * The name of Intel Corporation may not be used to endorse or promote products
27 // derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 //(including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort(including negligence or otherwise) arising in any way out of
38 // the use of this software, even ifadvised of the possibility of such damage.
39 //
40 //M*/
41
42 #include "precomp.hpp"
43
44 namespace cv
45 {
46 namespace ml
47 {
48
// Floor applied to every covariance eigenvalue before it is inverted or its
// logarithm is taken, keeping those operations finite for degenerate clusters.
static const double minEigenValue(DBL_EPSILON);
50
51 class CV_EXPORTS EMImpl CV_FINAL : public EM
52 {
53 public:
54
55 int nclusters;
56 int covMatType;
57 TermCriteria termCrit;
58
getTermCriteria() const59 inline TermCriteria getTermCriteria() const CV_OVERRIDE { return termCrit; }
setTermCriteria(const TermCriteria & val)60 inline void setTermCriteria(const TermCriteria& val) CV_OVERRIDE { termCrit = val; }
61
setClustersNumber(int val)62 void setClustersNumber(int val) CV_OVERRIDE
63 {
64 nclusters = val;
65 CV_Assert(nclusters >= 1);
66 }
67
getClustersNumber() const68 int getClustersNumber() const CV_OVERRIDE
69 {
70 return nclusters;
71 }
72
setCovarianceMatrixType(int val)73 void setCovarianceMatrixType(int val) CV_OVERRIDE
74 {
75 covMatType = val;
76 CV_Assert(covMatType == COV_MAT_SPHERICAL ||
77 covMatType == COV_MAT_DIAGONAL ||
78 covMatType == COV_MAT_GENERIC);
79 }
80
getCovarianceMatrixType() const81 int getCovarianceMatrixType() const CV_OVERRIDE
82 {
83 return covMatType;
84 }
85
EMImpl()86 EMImpl()
87 {
88 nclusters = DEFAULT_NCLUSTERS;
89 covMatType=EM::COV_MAT_DIAGONAL;
90 termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
91 }
92
~EMImpl()93 virtual ~EMImpl() {}
94
clear()95 void clear() CV_OVERRIDE
96 {
97 trainSamples.release();
98 trainProbs.release();
99 trainLogLikelihoods.release();
100 trainLabels.release();
101
102 weights.release();
103 means.release();
104 covs.clear();
105
106 covsEigenValues.clear();
107 invCovsEigenValues.clear();
108 covsRotateMats.clear();
109
110 logWeightDivDet.release();
111 }
112
train(const Ptr<TrainData> & data,int)113 bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE
114 {
115 CV_Assert(!data.empty());
116 Mat samples = data->getTrainSamples(), labels;
117 return trainEM(samples, labels, noArray(), noArray());
118 }
119
trainEM(InputArray samples,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)120 bool trainEM(InputArray samples,
121 OutputArray logLikelihoods,
122 OutputArray labels,
123 OutputArray probs) CV_OVERRIDE
124 {
125 Mat samplesMat = samples.getMat();
126 setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
127 return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
128 }
129
trainE(InputArray samples,InputArray _means0,InputArray _covs0,InputArray _weights0,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)130 bool trainE(InputArray samples,
131 InputArray _means0,
132 InputArray _covs0,
133 InputArray _weights0,
134 OutputArray logLikelihoods,
135 OutputArray labels,
136 OutputArray probs) CV_OVERRIDE
137 {
138 Mat samplesMat = samples.getMat();
139 std::vector<Mat> covs0;
140 _covs0.getMatVector(covs0);
141
142 Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
143
144 setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
145 !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
146 return doTrain(START_E_STEP, logLikelihoods, labels, probs);
147 }
148
trainM(InputArray samples,InputArray _probs0,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)149 bool trainM(InputArray samples,
150 InputArray _probs0,
151 OutputArray logLikelihoods,
152 OutputArray labels,
153 OutputArray probs) CV_OVERRIDE
154 {
155 Mat samplesMat = samples.getMat();
156 Mat probs0 = _probs0.getMat();
157
158 setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
159 return doTrain(START_M_STEP, logLikelihoods, labels, probs);
160 }
161
predict(InputArray _inputs,OutputArray _outputs,int) const162 float predict(InputArray _inputs, OutputArray _outputs, int) const CV_OVERRIDE
163 {
164 bool needprobs = _outputs.needed();
165 Mat samples = _inputs.getMat(), probs, probsrow;
166 int ptype = CV_64F;
167 float firstres = 0.f;
168 int i, nsamples = samples.rows;
169
170 if( needprobs )
171 {
172 if( _outputs.fixedType() )
173 ptype = _outputs.type();
174 _outputs.create(samples.rows, nclusters, ptype);
175 probs = _outputs.getMat();
176 }
177 else
178 nsamples = std::min(nsamples, 1);
179
180 for( i = 0; i < nsamples; i++ )
181 {
182 if( needprobs )
183 probsrow = probs.row(i);
184 Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
185 if( i == 0 )
186 firstres = (float)res[1];
187 }
188 return firstres;
189 }
190
predict2(InputArray _sample,OutputArray _probs) const191 Vec2d predict2(InputArray _sample, OutputArray _probs) const CV_OVERRIDE
192 {
193 int ptype = CV_64F;
194 Mat sample = _sample.getMat();
195 CV_Assert(isTrained());
196
197 CV_Assert(!sample.empty());
198 if(sample.type() != CV_64FC1)
199 {
200 Mat tmp;
201 sample.convertTo(tmp, CV_64FC1);
202 sample = tmp;
203 }
204 sample = sample.reshape(1, 1);
205
206 Mat probs;
207 if( _probs.needed() )
208 {
209 if( _probs.fixedType() )
210 ptype = _probs.type();
211 _probs.create(1, nclusters, ptype);
212 probs = _probs.getMat();
213 }
214
215 return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
216 }
217
isTrained() const218 bool isTrained() const CV_OVERRIDE
219 {
220 return !means.empty();
221 }
222
isClassifier() const223 bool isClassifier() const CV_OVERRIDE
224 {
225 return true;
226 }
227
getVarCount() const228 int getVarCount() const CV_OVERRIDE
229 {
230 return means.cols;
231 }
232
getDefaultName() const233 String getDefaultName() const CV_OVERRIDE
234 {
235 return "opencv_ml_em";
236 }
237
checkTrainData(int startStep,const Mat & samples,int nclusters,int covMatType,const Mat * probs,const Mat * means,const std::vector<Mat> * covs,const Mat * weights)238 static void checkTrainData(int startStep, const Mat& samples,
239 int nclusters, int covMatType, const Mat* probs, const Mat* means,
240 const std::vector<Mat>* covs, const Mat* weights)
241 {
242 // Check samples.
243 CV_Assert(!samples.empty());
244 CV_Assert(samples.channels() == 1);
245
246 int nsamples = samples.rows;
247 int dim = samples.cols;
248
249 // Check training params.
250 CV_Assert(nclusters > 0);
251 CV_Assert(nclusters <= nsamples);
252 CV_Assert(startStep == START_AUTO_STEP ||
253 startStep == START_E_STEP ||
254 startStep == START_M_STEP);
255 CV_Assert(covMatType == COV_MAT_GENERIC ||
256 covMatType == COV_MAT_DIAGONAL ||
257 covMatType == COV_MAT_SPHERICAL);
258
259 CV_Assert(!probs ||
260 (!probs->empty() &&
261 probs->rows == nsamples && probs->cols == nclusters &&
262 (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));
263
264 CV_Assert(!weights ||
265 (!weights->empty() &&
266 (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
267 (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));
268
269 CV_Assert(!means ||
270 (!means->empty() &&
271 means->rows == nclusters && means->cols == dim &&
272 means->channels() == 1));
273
274 CV_Assert(!covs ||
275 (!covs->empty() &&
276 static_cast<int>(covs->size()) == nclusters));
277 if(covs)
278 {
279 const Size covSize(dim, dim);
280 for(size_t i = 0; i < covs->size(); i++)
281 {
282 const Mat& m = (*covs)[i];
283 CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
284 }
285 }
286
287 if(startStep == START_E_STEP)
288 {
289 CV_Assert(means);
290 }
291 else if(startStep == START_M_STEP)
292 {
293 CV_Assert(probs);
294 }
295 }
296
preprocessSampleData(const Mat & src,Mat & dst,int dstType,bool isAlwaysClone)297 static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
298 {
299 if(src.type() == dstType && !isAlwaysClone)
300 dst = src;
301 else
302 src.convertTo(dst, dstType);
303 }
304
preprocessProbability(Mat & probs)305 static void preprocessProbability(Mat& probs)
306 {
307 max(probs, 0., probs);
308
309 const double uniformProbability = (double)(1./probs.cols);
310 for(int y = 0; y < probs.rows; y++)
311 {
312 Mat sampleProbs = probs.row(y);
313
314 double maxVal = 0;
315 minMaxLoc(sampleProbs, 0, &maxVal);
316 if(maxVal < FLT_EPSILON)
317 sampleProbs.setTo(uniformProbability);
318 else
319 normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
320 }
321 }
322
setTrainData(int startStep,const Mat & samples,const Mat * probs0,const Mat * means0,const std::vector<Mat> * covs0,const Mat * weights0)323 void setTrainData(int startStep, const Mat& samples,
324 const Mat* probs0,
325 const Mat* means0,
326 const std::vector<Mat>* covs0,
327 const Mat* weights0)
328 {
329 clear();
330
331 checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
332
333 bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
334 // Set checked data
335 preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);
336
337 // set probs
338 if(probs0 && startStep == START_M_STEP)
339 {
340 preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
341 preprocessProbability(trainProbs);
342 }
343
344 // set weights
345 if(weights0 && (startStep == START_E_STEP && covs0))
346 {
347 weights0->convertTo(weights, CV_64FC1);
348 weights = weights.reshape(1,1);
349 preprocessProbability(weights);
350 }
351
352 // set means
353 if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
354 means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);
355
356 // set covs
357 if(covs0 && (startStep == START_E_STEP && weights0))
358 {
359 covs.resize(nclusters);
360 for(size_t i = 0; i < covs0->size(); i++)
361 (*covs0)[i].convertTo(covs[i], CV_64FC1);
362 }
363 }
364
decomposeCovs()365 void decomposeCovs()
366 {
367 CV_Assert(!covs.empty());
368 covsEigenValues.resize(nclusters);
369 if(covMatType == COV_MAT_GENERIC)
370 covsRotateMats.resize(nclusters);
371 invCovsEigenValues.resize(nclusters);
372 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
373 {
374 CV_Assert(!covs[clusterIndex].empty());
375
376 SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
377
378 if(covMatType == COV_MAT_SPHERICAL)
379 {
380 double maxSingularVal = svd.w.at<double>(0);
381 covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
382 }
383 else if(covMatType == COV_MAT_DIAGONAL)
384 {
385 covsEigenValues[clusterIndex] = covs[clusterIndex].diag().clone(); //Preserve the original order of eigen values.
386 }
387 else //COV_MAT_GENERIC
388 {
389 covsEigenValues[clusterIndex] = svd.w;
390 covsRotateMats[clusterIndex] = svd.u;
391 }
392 max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
393 invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
394 }
395 }
396
clusterTrainSamples()397 void clusterTrainSamples()
398 {
399 int nsamples = trainSamples.rows;
400
401 // Cluster samples, compute/update means
402
403 // Convert samples and means to 32F, because kmeans requires this type.
404 Mat trainSamplesFlt, meansFlt;
405 if(trainSamples.type() != CV_32FC1)
406 trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
407 else
408 trainSamplesFlt = trainSamples;
409 if(!means.empty())
410 {
411 if(means.type() != CV_32FC1)
412 means.convertTo(meansFlt, CV_32FC1);
413 else
414 meansFlt = means;
415 }
416
417 Mat labels;
418 kmeans(trainSamplesFlt, nclusters, labels,
419 TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
420 10, KMEANS_PP_CENTERS, meansFlt);
421
422 // Convert samples and means back to 64F.
423 CV_Assert(meansFlt.type() == CV_32FC1);
424 if(trainSamples.type() != CV_64FC1)
425 {
426 Mat trainSamplesBuffer;
427 trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
428 trainSamples = trainSamplesBuffer;
429 }
430 meansFlt.convertTo(means, CV_64FC1);
431
432 // Compute weights and covs
433 weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
434 covs.resize(nclusters);
435 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
436 {
437 Mat clusterSamples;
438 for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
439 {
440 if(labels.at<int>(sampleIndex) == clusterIndex)
441 {
442 const Mat sample = trainSamples.row(sampleIndex);
443 clusterSamples.push_back(sample);
444 }
445 }
446 CV_Assert(!clusterSamples.empty());
447
448 calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
449 CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
450 weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
451 }
452
453 decomposeCovs();
454 }
455
computeLogWeightDivDet()456 void computeLogWeightDivDet()
457 {
458 CV_Assert(!covsEigenValues.empty());
459
460 Mat logWeights;
461 cv::max(weights, DBL_MIN, weights);
462 log(weights, logWeights);
463
464 logWeightDivDet.create(1, nclusters, CV_64FC1);
465 // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
466
467 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
468 {
469 double logDetCov = 0.;
470 const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
471 for(int di = 0; di < evalCount; di++)
472 logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));
473
474 logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
475 }
476 }
477
doTrain(int startStep,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)478 bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
479 {
480 int dim = trainSamples.cols;
481 // Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
482 if(startStep != START_M_STEP)
483 {
484 if(covs.empty())
485 {
486 CV_Assert(weights.empty());
487 clusterTrainSamples();
488 }
489 }
490
491 if(!covs.empty() && covsEigenValues.empty() )
492 {
493 CV_Assert(invCovsEigenValues.empty());
494 decomposeCovs();
495 }
496
497 if(startStep == START_M_STEP)
498 mStep();
499
500 double trainLogLikelihood, prevTrainLogLikelihood = 0.;
501 int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
502 termCrit.maxCount : DEFAULT_MAX_ITERS;
503 double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;
504
505 for(int iter = 0; ; iter++)
506 {
507 eStep();
508 trainLogLikelihood = sum(trainLogLikelihoods)[0];
509
510 if(iter >= maxIters - 1)
511 break;
512
513 double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
514 if( iter != 0 &&
515 (trainLogLikelihoodDelta < -DBL_EPSILON ||
516 trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
517 break;
518
519 mStep();
520
521 prevTrainLogLikelihood = trainLogLikelihood;
522 }
523
524 if( trainLogLikelihood <= -DBL_MAX/10000. )
525 {
526 clear();
527 return false;
528 }
529
530 // postprocess covs
531 covs.resize(nclusters);
532 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
533 {
534 if(covMatType == COV_MAT_SPHERICAL)
535 {
536 covs[clusterIndex].create(dim, dim, CV_64FC1);
537 setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
538 }
539 else if(covMatType == COV_MAT_DIAGONAL)
540 {
541 covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
542 }
543 }
544
545 if(labels.needed())
546 trainLabels.copyTo(labels);
547 if(probs.needed())
548 trainProbs.copyTo(probs);
549 if(logLikelihoods.needed())
550 trainLogLikelihoods.copyTo(logLikelihoods);
551
552 trainSamples.release();
553 trainProbs.release();
554 trainLabels.release();
555 trainLogLikelihoods.release();
556
557 return true;
558 }
559
computeProbabilities(const Mat & sample,Mat * probs,int ptype) const560 Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
561 {
562 // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
563 // q = arg(max_k(L_ik))
564 // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
565 // see Alex Smola's blog http://blog.smola.org/page/2 for
566 // details on the log-sum-exp trick
567
568 int stype = sample.type();
569 CV_Assert(!means.empty());
570 CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
571 CV_Assert(sample.size() == Size(means.cols, 1));
572
573 int dim = sample.cols;
574
575 Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
576 int i, label = 0;
577 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
578 {
579 const double* mptr = means.ptr<double>(clusterIndex);
580 double* dptr = centeredSample.ptr<double>();
581 if( stype == CV_32F )
582 {
583 const float* sptr = sample.ptr<float>();
584 for( i = 0; i < dim; i++ )
585 dptr[i] = sptr[i] - mptr[i];
586 }
587 else
588 {
589 const double* sptr = sample.ptr<double>();
590 for( i = 0; i < dim; i++ )
591 dptr[i] = sptr[i] - mptr[i];
592 }
593
594 Mat rotatedCenteredSample = covMatType != COV_MAT_GENERIC ?
595 centeredSample : centeredSample * covsRotateMats[clusterIndex];
596
597 double Lval = 0;
598 for(int di = 0; di < dim; di++)
599 {
600 double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0);
601 double val = rotatedCenteredSample.at<double>(di);
602 Lval += w * val * val;
603 }
604 CV_DbgAssert(!logWeightDivDet.empty());
605 L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;
606
607 if(L.at<double>(clusterIndex) > L.at<double>(label))
608 label = clusterIndex;
609 }
610
611 double maxLVal = L.at<double>(label);
612 double expDiffSum = 0;
613 for( i = 0; i < L.cols; i++ )
614 {
615 double v = std::exp(L.at<double>(i) - maxLVal);
616 L.at<double>(i) = v;
617 expDiffSum += v; // sum_j(exp(L_ij - L_iq))
618 }
619
620 CV_Assert(expDiffSum > 0);
621 if(probs)
622 L.convertTo(*probs, ptype, 1./expDiffSum);
623
624 Vec2d res;
625 res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
626 res[1] = label;
627
628 return res;
629 }
630
eStep()631 void eStep()
632 {
633 // Compute probs_ik from means_k, covs_k and weights_k.
634 trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
635 trainLabels.create(trainSamples.rows, 1, CV_32SC1);
636 trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
637
638 computeLogWeightDivDet();
639
640 CV_DbgAssert(trainSamples.type() == CV_64FC1);
641 CV_DbgAssert(means.type() == CV_64FC1);
642
643 for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
644 {
645 Mat sampleProbs = trainProbs.row(sampleIndex);
646 Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
647 trainLogLikelihoods.at<double>(sampleIndex) = res[0];
648 trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
649 }
650 }
651
mStep()652 void mStep()
653 {
654 // Update means_k, covs_k and weights_k from probs_ik
655 int dim = trainSamples.cols;
656
657 // Update weights
658 // not normalized first
659 reduce(trainProbs, weights, 0, CV_REDUCE_SUM);
660
661 // Update means
662 means.create(nclusters, dim, CV_64FC1);
663 means = Scalar(0);
664
665 const double minPosWeight = trainSamples.rows * DBL_EPSILON;
666 double minWeight = DBL_MAX;
667 int minWeightClusterIndex = -1;
668 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
669 {
670 if(weights.at<double>(clusterIndex) <= minPosWeight)
671 continue;
672
673 if(weights.at<double>(clusterIndex) < minWeight)
674 {
675 minWeight = weights.at<double>(clusterIndex);
676 minWeightClusterIndex = clusterIndex;
677 }
678
679 Mat clusterMean = means.row(clusterIndex);
680 for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
681 clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
682 clusterMean /= weights.at<double>(clusterIndex);
683 }
684
685 // Update covsEigenValues and invCovsEigenValues
686 covs.resize(nclusters);
687 covsEigenValues.resize(nclusters);
688 if(covMatType == COV_MAT_GENERIC)
689 covsRotateMats.resize(nclusters);
690 invCovsEigenValues.resize(nclusters);
691 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
692 {
693 if(weights.at<double>(clusterIndex) <= minPosWeight)
694 continue;
695
696 if(covMatType != COV_MAT_SPHERICAL)
697 covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
698 else
699 covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);
700
701 if(covMatType == COV_MAT_GENERIC)
702 covs[clusterIndex].create(dim, dim, CV_64FC1);
703
704 Mat clusterCov = covMatType != COV_MAT_GENERIC ?
705 covsEigenValues[clusterIndex] : covs[clusterIndex];
706
707 clusterCov = Scalar(0);
708
709 Mat centeredSample;
710 for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
711 {
712 centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);
713
714 if(covMatType == COV_MAT_GENERIC)
715 clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
716 else
717 {
718 double p = trainProbs.at<double>(sampleIndex, clusterIndex);
719 for(int di = 0; di < dim; di++ )
720 {
721 double val = centeredSample.at<double>(di);
722 clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
723 }
724 }
725 }
726
727 if(covMatType == COV_MAT_SPHERICAL)
728 clusterCov /= dim;
729
730 clusterCov /= weights.at<double>(clusterIndex);
731
732 // Update covsRotateMats for COV_MAT_GENERIC only
733 if(covMatType == COV_MAT_GENERIC)
734 {
735 SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
736 covsEigenValues[clusterIndex] = svd.w;
737 covsRotateMats[clusterIndex] = svd.u;
738 }
739
740 max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
741
742 // update invCovsEigenValues
743 invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
744 }
745
746 for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
747 {
748 if(weights.at<double>(clusterIndex) <= minPosWeight)
749 {
750 Mat clusterMean = means.row(clusterIndex);
751 means.row(minWeightClusterIndex).copyTo(clusterMean);
752 covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
753 covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
754 if(covMatType == COV_MAT_GENERIC)
755 covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
756 invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
757 }
758 }
759
760 // Normalize weights
761 weights /= trainSamples.rows;
762 }
763
write_params(FileStorage & fs) const764 void write_params(FileStorage& fs) const
765 {
766 fs << "nclusters" << nclusters;
767 fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
768 covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
769 covMatType == COV_MAT_GENERIC ? String("generic") :
770 format("unknown_%d", covMatType));
771 writeTermCrit(fs, termCrit);
772 }
773
write(FileStorage & fs) const774 void write(FileStorage& fs) const CV_OVERRIDE
775 {
776 writeFormat(fs);
777 fs << "training_params" << "{";
778 write_params(fs);
779 fs << "}";
780 fs << "weights" << weights;
781 fs << "means" << means;
782
783 size_t i, n = covs.size();
784
785 fs << "covs" << "[";
786 for( i = 0; i < n; i++ )
787 fs << covs[i];
788 fs << "]";
789 }
790
read_params(const FileNode & fn)791 void read_params(const FileNode& fn)
792 {
793 nclusters = (int)fn["nclusters"];
794 String s = (String)fn["cov_mat_type"];
795 covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
796 s == "diagonal" ? COV_MAT_DIAGONAL :
797 s == "generic" ? COV_MAT_GENERIC : -1;
798 CV_Assert(covMatType >= 0);
799 termCrit = readTermCrit(fn);
800 }
801
read(const FileNode & fn)802 void read(const FileNode& fn) CV_OVERRIDE
803 {
804 clear();
805 read_params(fn["training_params"]);
806
807 fn["weights"] >> weights;
808 fn["means"] >> means;
809
810 FileNode cfn = fn["covs"];
811 FileNodeIterator cfn_it = cfn.begin();
812 int i, n = (int)cfn.size();
813 covs.resize(n);
814
815 for( i = 0; i < n; i++, ++cfn_it )
816 (*cfn_it) >> covs[i];
817
818 decomposeCovs();
819 computeLogWeightDivDet();
820 }
821
getWeights() const822 Mat getWeights() const CV_OVERRIDE { return weights; }
getMeans() const823 Mat getMeans() const CV_OVERRIDE { return means; }
getCovs(std::vector<Mat> & _covs) const824 void getCovs(std::vector<Mat>& _covs) const CV_OVERRIDE
825 {
826 _covs.resize(covs.size());
827 std::copy(covs.begin(), covs.end(), _covs.begin());
828 }
829
830 // all inner matrices have type CV_64FC1
831 Mat trainSamples;
832 Mat trainProbs;
833 Mat trainLogLikelihoods;
834 Mat trainLabels;
835
836 Mat weights;
837 Mat means;
838 std::vector<Mat> covs;
839
840 std::vector<Mat> covsEigenValues;
841 std::vector<Mat> covsRotateMats;
842 std::vector<Mat> invCovsEigenValues;
843 Mat logWeightDivDet;
844 };
845
create()846 Ptr<EM> EM::create()
847 {
848 return makePtr<EMImpl>();
849 }
850
load(const String & filepath,const String & nodeName)851 Ptr<EM> EM::load(const String& filepath, const String& nodeName)
852 {
853 return Algorithm::load<EM>(filepath, nodeName);
854 }
855
856 }
857 } // namespace cv
858
859 /* End of file. */
860