1 /*M///////////////////////////////////////////////////////////////////////////////////////
2 //
3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
4 //
5 //  By downloading, copying, installing or using the software you agree to this license.
6 //  If you do not agree to this license, do not download, install,
7 //  copy or use the software.
8 //
9 //
10 //                        Intel License Agreement
11 //                For Open Source Computer Vision Library
12 //
13 // Copyright( C) 2000, Intel Corporation, all rights reserved.
14 // Third party copyrights are property of their respective owners.
15 //
16 // Redistribution and use in source and binary forms, with or without modification,
17 // are permitted provided that the following conditions are met:
18 //
19 //   * Redistribution's of source code must retain the above copyright notice,
20 //     this list of conditions and the following disclaimer.
21 //
22 //   * Redistribution's in binary form must reproduce the above copyright notice,
23 //     this list of conditions and the following disclaimer in the documentation
24 //     and/or other materials provided with the distribution.
25 //
26 //   * The name of Intel Corporation may not be used to endorse or promote products
27 //     derived from this software without specific prior written permission.
28 //
29 // This software is provided by the copyright holders and contributors "as is" and
30 // any express or implied warranties, including, but not limited to, the implied
31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
32 // In no event shall the Intel Corporation or contributors be liable for any direct,
33 // indirect, incidental, special, exemplary, or consequential damages
34 //(including, but not limited to, procurement of substitute goods or services;
35 // loss of use, data, or profits; or business interruption) however caused
36 // and on any theory of liability, whether in contract, strict liability,
37 // or tort(including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
39 //
40 //M*/
41 
42 #include "precomp.hpp"
43 
44 namespace cv
45 {
46 namespace ml
47 {
48 
// Lower bound applied to every covariance eigenvalue (see the max() calls in
// decomposeCovs() and mStep()) before the values are inverted into
// invCovsEigenValues and fed to std::log in computeLogWeightDivDet().
const double minEigenValue = DBL_EPSILON;
50 
51 class CV_EXPORTS EMImpl CV_FINAL : public EM
52 {
53 public:
54 
    // Number of mixture components; setClustersNumber() requires it to be >= 1.
    int nclusters;
    // Covariance model: COV_MAT_SPHERICAL, COV_MAT_DIAGONAL or COV_MAT_GENERIC.
    int covMatType;
    // Stopping rule consulted by doTrain() (iteration cap and/or epsilon).
    TermCriteria termCrit;
58 
getTermCriteria() const59     inline TermCriteria getTermCriteria() const CV_OVERRIDE { return termCrit; }
setTermCriteria(const TermCriteria & val)60     inline void setTermCriteria(const TermCriteria& val) CV_OVERRIDE { termCrit = val; }
61 
setClustersNumber(int val)62     void setClustersNumber(int val) CV_OVERRIDE
63     {
64         nclusters = val;
65         CV_Assert(nclusters >= 1);
66     }
67 
getClustersNumber() const68     int getClustersNumber() const CV_OVERRIDE
69     {
70         return nclusters;
71     }
72 
setCovarianceMatrixType(int val)73     void setCovarianceMatrixType(int val) CV_OVERRIDE
74     {
75         covMatType = val;
76         CV_Assert(covMatType == COV_MAT_SPHERICAL ||
77                   covMatType == COV_MAT_DIAGONAL ||
78                   covMatType == COV_MAT_GENERIC);
79     }
80 
getCovarianceMatrixType() const81     int getCovarianceMatrixType() const CV_OVERRIDE
82     {
83         return covMatType;
84     }
85 
EMImpl()86     EMImpl()
87     {
88         nclusters = DEFAULT_NCLUSTERS;
89         covMatType=EM::COV_MAT_DIAGONAL;
90         termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
91     }
92 
~EMImpl()93     virtual ~EMImpl() {}
94 
clear()95     void clear() CV_OVERRIDE
96     {
97         trainSamples.release();
98         trainProbs.release();
99         trainLogLikelihoods.release();
100         trainLabels.release();
101 
102         weights.release();
103         means.release();
104         covs.clear();
105 
106         covsEigenValues.clear();
107         invCovsEigenValues.clear();
108         covsRotateMats.clear();
109 
110         logWeightDivDet.release();
111     }
112 
train(const Ptr<TrainData> & data,int)113     bool train(const Ptr<TrainData>& data, int) CV_OVERRIDE
114     {
115         CV_Assert(!data.empty());
116         Mat samples = data->getTrainSamples(), labels;
117         return trainEM(samples, labels, noArray(), noArray());
118     }
119 
trainEM(InputArray samples,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)120     bool trainEM(InputArray samples,
121                OutputArray logLikelihoods,
122                OutputArray labels,
123                OutputArray probs) CV_OVERRIDE
124     {
125         Mat samplesMat = samples.getMat();
126         setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
127         return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
128     }
129 
trainE(InputArray samples,InputArray _means0,InputArray _covs0,InputArray _weights0,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)130     bool trainE(InputArray samples,
131                 InputArray _means0,
132                 InputArray _covs0,
133                 InputArray _weights0,
134                 OutputArray logLikelihoods,
135                 OutputArray labels,
136                 OutputArray probs) CV_OVERRIDE
137     {
138         Mat samplesMat = samples.getMat();
139         std::vector<Mat> covs0;
140         _covs0.getMatVector(covs0);
141 
142         Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();
143 
144         setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
145                      !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
146         return doTrain(START_E_STEP, logLikelihoods, labels, probs);
147     }
148 
trainM(InputArray samples,InputArray _probs0,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)149     bool trainM(InputArray samples,
150                 InputArray _probs0,
151                 OutputArray logLikelihoods,
152                 OutputArray labels,
153                 OutputArray probs) CV_OVERRIDE
154     {
155         Mat samplesMat = samples.getMat();
156         Mat probs0 = _probs0.getMat();
157 
158         setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
159         return doTrain(START_M_STEP, logLikelihoods, labels, probs);
160     }
161 
predict(InputArray _inputs,OutputArray _outputs,int) const162     float predict(InputArray _inputs, OutputArray _outputs, int) const CV_OVERRIDE
163     {
164         bool needprobs = _outputs.needed();
165         Mat samples = _inputs.getMat(), probs, probsrow;
166         int ptype = CV_64F;
167         float firstres = 0.f;
168         int i, nsamples = samples.rows;
169 
170         if( needprobs )
171         {
172             if( _outputs.fixedType() )
173                 ptype = _outputs.type();
174             _outputs.create(samples.rows, nclusters, ptype);
175             probs = _outputs.getMat();
176         }
177         else
178             nsamples = std::min(nsamples, 1);
179 
180         for( i = 0; i < nsamples; i++ )
181         {
182             if( needprobs )
183                 probsrow = probs.row(i);
184             Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
185             if( i == 0 )
186                 firstres = (float)res[1];
187         }
188         return firstres;
189     }
190 
predict2(InputArray _sample,OutputArray _probs) const191     Vec2d predict2(InputArray _sample, OutputArray _probs) const CV_OVERRIDE
192     {
193         int ptype = CV_64F;
194         Mat sample = _sample.getMat();
195         CV_Assert(isTrained());
196 
197         CV_Assert(!sample.empty());
198         if(sample.type() != CV_64FC1)
199         {
200             Mat tmp;
201             sample.convertTo(tmp, CV_64FC1);
202             sample = tmp;
203         }
204         sample = sample.reshape(1, 1);
205 
206         Mat probs;
207         if( _probs.needed() )
208         {
209             if( _probs.fixedType() )
210                 ptype = _probs.type();
211             _probs.create(1, nclusters, ptype);
212             probs = _probs.getMat();
213         }
214 
215         return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
216     }
217 
isTrained() const218     bool isTrained() const CV_OVERRIDE
219     {
220         return !means.empty();
221     }
222 
isClassifier() const223     bool isClassifier() const CV_OVERRIDE
224     {
225         return true;
226     }
227 
getVarCount() const228     int getVarCount() const CV_OVERRIDE
229     {
230         return means.cols;
231     }
232 
getDefaultName() const233     String getDefaultName() const CV_OVERRIDE
234     {
235         return "opencv_ml_em";
236     }
237 
    // Validates the training inputs; every violated condition raises the
    // corresponding CV_Assert error. Nothing is modified.
    //   startStep  - must be START_AUTO_STEP, START_E_STEP or START_M_STEP.
    //   samples    - non-empty, single-channel; rows are samples, columns are
    //                dimensions.
    //   nclusters  - must be in [1, number of samples].
    //   covMatType - must be one of the COV_MAT_* values.
    //   probs, means, covs, weights - optional (null means "not supplied");
    //                when supplied, their sizes must match nclusters / the
    //                sample dimensions as checked below.
    static void checkTrainData(int startStep, const Mat& samples,
                               int nclusters, int covMatType, const Mat* probs, const Mat* means,
                               const std::vector<Mat>* covs, const Mat* weights)
    {
        // Check samples.
        CV_Assert(!samples.empty());
        CV_Assert(samples.channels() == 1);

        int nsamples = samples.rows;
        int dim = samples.cols;

        // Check training params.
        CV_Assert(nclusters > 0);
        CV_Assert(nclusters <= nsamples);
        CV_Assert(startStep == START_AUTO_STEP ||
                  startStep == START_E_STEP ||
                  startStep == START_M_STEP);
        CV_Assert(covMatType == COV_MAT_GENERIC ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_SPHERICAL);

        // probs: nsamples x nclusters, CV_32FC1 or CV_64FC1.
        CV_Assert(!probs ||
            (!probs->empty() &&
             probs->rows == nsamples && probs->cols == nclusters &&
             (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));

        // weights: a row or column vector with nclusters elements, CV_32FC1 or CV_64FC1.
        CV_Assert(!weights ||
            (!weights->empty() &&
             (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
             (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));

        // means: nclusters x dim, single-channel (the depth is not checked here).
        CV_Assert(!means ||
            (!means->empty() &&
             means->rows == nclusters && means->cols == dim &&
             means->channels() == 1));

        // covs: exactly nclusters matrices, each non-empty, dim x dim, single-channel.
        CV_Assert(!covs ||
            (!covs->empty() &&
             static_cast<int>(covs->size()) == nclusters));
        if(covs)
        {
            const Size covSize(dim, dim);
            for(size_t i = 0; i < covs->size(); i++)
            {
                const Mat& m = (*covs)[i];
                CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
            }
        }

        // The chosen start step dictates which initial estimate is mandatory:
        // means for the E-step, probabilities for the M-step.
        if(startStep == START_E_STEP)
        {
            CV_Assert(means);
        }
        else if(startStep == START_M_STEP)
        {
            CV_Assert(probs);
        }
    }
296 
preprocessSampleData(const Mat & src,Mat & dst,int dstType,bool isAlwaysClone)297     static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
298     {
299         if(src.type() == dstType && !isAlwaysClone)
300             dst = src;
301         else
302             src.convertTo(dst, dstType);
303     }
304 
preprocessProbability(Mat & probs)305     static void preprocessProbability(Mat& probs)
306     {
307         max(probs, 0., probs);
308 
309         const double uniformProbability = (double)(1./probs.cols);
310         for(int y = 0; y < probs.rows; y++)
311         {
312             Mat sampleProbs = probs.row(y);
313 
314             double maxVal = 0;
315             minMaxLoc(sampleProbs, 0, &maxVal);
316             if(maxVal < FLT_EPSILON)
317                 sampleProbs.setTo(uniformProbability);
318             else
319                 normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
320         }
321     }
322 
setTrainData(int startStep,const Mat & samples,const Mat * probs0,const Mat * means0,const std::vector<Mat> * covs0,const Mat * weights0)323     void setTrainData(int startStep, const Mat& samples,
324                       const Mat* probs0,
325                       const Mat* means0,
326                       const std::vector<Mat>* covs0,
327                       const Mat* weights0)
328     {
329         clear();
330 
331         checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);
332 
333         bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
334         // Set checked data
335         preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);
336 
337         // set probs
338         if(probs0 && startStep == START_M_STEP)
339         {
340             preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
341             preprocessProbability(trainProbs);
342         }
343 
344         // set weights
345         if(weights0 && (startStep == START_E_STEP && covs0))
346         {
347             weights0->convertTo(weights, CV_64FC1);
348             weights = weights.reshape(1,1);
349             preprocessProbability(weights);
350         }
351 
352         // set means
353         if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
354             means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);
355 
356         // set covs
357         if(covs0 && (startStep == START_E_STEP && weights0))
358         {
359             covs.resize(nclusters);
360             for(size_t i = 0; i < covs0->size(); i++)
361                 (*covs0)[i].convertTo(covs[i], CV_64FC1);
362         }
363     }
364 
decomposeCovs()365     void decomposeCovs()
366     {
367         CV_Assert(!covs.empty());
368         covsEigenValues.resize(nclusters);
369         if(covMatType == COV_MAT_GENERIC)
370             covsRotateMats.resize(nclusters);
371         invCovsEigenValues.resize(nclusters);
372         for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
373         {
374             CV_Assert(!covs[clusterIndex].empty());
375 
376             SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
377 
378             if(covMatType == COV_MAT_SPHERICAL)
379             {
380                 double maxSingularVal = svd.w.at<double>(0);
381                 covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
382             }
383             else if(covMatType == COV_MAT_DIAGONAL)
384             {
385                 covsEigenValues[clusterIndex] = covs[clusterIndex].diag().clone(); //Preserve the original order of eigen values.
386             }
387             else //COV_MAT_GENERIC
388             {
389                 covsEigenValues[clusterIndex] = svd.w;
390                 covsRotateMats[clusterIndex] = svd.u;
391             }
392             max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
393             invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
394         }
395     }
396 
    // Builds the initial model from a k-means clustering of trainSamples:
    // means come from the cluster centres, each covariance from the samples
    // assigned to that cluster, and each weight is the fraction of samples in
    // the cluster. Ends by calling decomposeCovs(). On return trainSamples and
    // means are CV_64FC1.
    void clusterTrainSamples()
    {
        int nsamples = trainSamples.rows;

        // Cluster samples, compute/update means

        // Convert samples and means to 32F, because kmeans requires this type.
        Mat trainSamplesFlt, meansFlt;
        if(trainSamples.type() != CV_32FC1)
            trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
        else
            trainSamplesFlt = trainSamples;
        if(!means.empty())
        {
            if(means.type() != CV_32FC1)
                means.convertTo(meansFlt, CV_32FC1);
            else
                meansFlt = means;
        }

        // The iteration cap is 10 without initial means and 1 with them.
        // NOTE(review): meansFlt is passed in the centers slot together with
        // KMEANS_PP_CENTERS; confirm whether caller-supplied means actually
        // seed the clustering or are simply overwritten.
        Mat labels;
        kmeans(trainSamplesFlt, nclusters, labels,
               TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
               10, KMEANS_PP_CENTERS, meansFlt);

        // Convert samples and means back to 64F.
        CV_Assert(meansFlt.type() == CV_32FC1);
        if(trainSamples.type() != CV_64FC1)
        {
            Mat trainSamplesBuffer;
            trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
            trainSamples = trainSamplesBuffer;
        }
        meansFlt.convertTo(means, CV_64FC1);

        // Compute weights and covs
        weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
        covs.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            // Gather the rows that k-means assigned to this cluster.
            Mat clusterSamples;
            for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
            {
                if(labels.at<int>(sampleIndex) == clusterIndex)
                {
                    const Mat sample = trainSamples.row(sampleIndex);
                    clusterSamples.push_back(sample);
                }
            }
            // A cluster without any sample cannot provide a covariance.
            CV_Assert(!clusterSamples.empty());

            // CV_COVAR_USE_AVG: the cluster's row of means is supplied as the average.
            calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
                CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
            weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
        }

        decomposeCovs();
    }
455 
computeLogWeightDivDet()456     void computeLogWeightDivDet()
457     {
458         CV_Assert(!covsEigenValues.empty());
459 
460         Mat logWeights;
461         cv::max(weights, DBL_MIN, weights);
462         log(weights, logWeights);
463 
464         logWeightDivDet.create(1, nclusters, CV_64FC1);
465         // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
466 
467         for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
468         {
469             double logDetCov = 0.;
470             const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
471             for(int di = 0; di < evalCount; di++)
472                 logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));
473 
474             logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
475         }
476     }
477 
doTrain(int startStep,OutputArray logLikelihoods,OutputArray labels,OutputArray probs)478     bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
479     {
480         int dim = trainSamples.cols;
481         // Precompute the empty initial train data in the cases of START_E_STEP and START_AUTO_STEP
482         if(startStep != START_M_STEP)
483         {
484             if(covs.empty())
485             {
486                 CV_Assert(weights.empty());
487                 clusterTrainSamples();
488             }
489         }
490 
491         if(!covs.empty() && covsEigenValues.empty() )
492         {
493             CV_Assert(invCovsEigenValues.empty());
494             decomposeCovs();
495         }
496 
497         if(startStep == START_M_STEP)
498             mStep();
499 
500         double trainLogLikelihood, prevTrainLogLikelihood = 0.;
501         int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
502             termCrit.maxCount : DEFAULT_MAX_ITERS;
503         double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;
504 
505         for(int iter = 0; ; iter++)
506         {
507             eStep();
508             trainLogLikelihood = sum(trainLogLikelihoods)[0];
509 
510             if(iter >= maxIters - 1)
511                 break;
512 
513             double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
514             if( iter != 0 &&
515                 (trainLogLikelihoodDelta < -DBL_EPSILON ||
516                  trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
517                 break;
518 
519             mStep();
520 
521             prevTrainLogLikelihood = trainLogLikelihood;
522         }
523 
524         if( trainLogLikelihood <= -DBL_MAX/10000. )
525         {
526             clear();
527             return false;
528         }
529 
530         // postprocess covs
531         covs.resize(nclusters);
532         for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
533         {
534             if(covMatType == COV_MAT_SPHERICAL)
535             {
536                 covs[clusterIndex].create(dim, dim, CV_64FC1);
537                 setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
538             }
539             else if(covMatType == COV_MAT_DIAGONAL)
540             {
541                 covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
542             }
543         }
544 
545         if(labels.needed())
546             trainLabels.copyTo(labels);
547         if(probs.needed())
548             trainProbs.copyTo(probs);
549         if(logLikelihoods.needed())
550             trainLogLikelihoods.copyTo(logLikelihoods);
551 
552         trainSamples.release();
553         trainProbs.release();
554         trainLabels.release();
555         trainLogLikelihoods.release();
556 
557         return true;
558     }
559 
computeProbabilities(const Mat & sample,Mat * probs,int ptype) const560     Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
561     {
562         // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
563         // q = arg(max_k(L_ik))
564         // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
565         // see Alex Smola's blog http://blog.smola.org/page/2 for
566         // details on the log-sum-exp trick
567 
568         int stype = sample.type();
569         CV_Assert(!means.empty());
570         CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
571         CV_Assert(sample.size() == Size(means.cols, 1));
572 
573         int dim = sample.cols;
574 
575         Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
576         int i, label = 0;
577         for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
578         {
579             const double* mptr = means.ptr<double>(clusterIndex);
580             double* dptr = centeredSample.ptr<double>();
581             if( stype == CV_32F )
582             {
583                 const float* sptr = sample.ptr<float>();
584                 for( i = 0; i < dim; i++ )
585                     dptr[i] = sptr[i] - mptr[i];
586             }
587             else
588             {
589                 const double* sptr = sample.ptr<double>();
590                 for( i = 0; i < dim; i++ )
591                     dptr[i] = sptr[i] - mptr[i];
592             }
593 
594             Mat rotatedCenteredSample = covMatType != COV_MAT_GENERIC ?
595                     centeredSample : centeredSample * covsRotateMats[clusterIndex];
596 
597             double Lval = 0;
598             for(int di = 0; di < dim; di++)
599             {
600                 double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0);
601                 double val = rotatedCenteredSample.at<double>(di);
602                 Lval += w * val * val;
603             }
604             CV_DbgAssert(!logWeightDivDet.empty());
605             L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;
606 
607             if(L.at<double>(clusterIndex) > L.at<double>(label))
608                 label = clusterIndex;
609         }
610 
611         double maxLVal = L.at<double>(label);
612         double expDiffSum = 0;
613         for( i = 0; i < L.cols; i++ )
614         {
615             double v = std::exp(L.at<double>(i) - maxLVal);
616             L.at<double>(i) = v;
617             expDiffSum += v; // sum_j(exp(L_ij - L_iq))
618         }
619 
620         CV_Assert(expDiffSum > 0);
621         if(probs)
622             L.convertTo(*probs, ptype, 1./expDiffSum);
623 
624         Vec2d res;
625         res[0] = std::log(expDiffSum)  + maxLVal - 0.5 * dim * CV_LOG2PI;
626         res[1] = label;
627 
628         return res;
629     }
630 
eStep()631     void eStep()
632     {
633         // Compute probs_ik from means_k, covs_k and weights_k.
634         trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
635         trainLabels.create(trainSamples.rows, 1, CV_32SC1);
636         trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);
637 
638         computeLogWeightDivDet();
639 
640         CV_DbgAssert(trainSamples.type() == CV_64FC1);
641         CV_DbgAssert(means.type() == CV_64FC1);
642 
643         for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
644         {
645             Mat sampleProbs = trainProbs.row(sampleIndex);
646             Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
647             trainLogLikelihoods.at<double>(sampleIndex) = res[0];
648             trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
649         }
650     }
651 
    // Maximization step: re-estimates weights, means and the covariance
    // eigen-representation from trainProbs and trainSamples.
    // Clusters whose summed probability is not above trainSamples.rows *
    // DBL_EPSILON are skipped during the update and afterwards receive a copy
    // of the parameters of the smallest non-skipped cluster.
    void mStep()
    {
        // Update means_k, covs_k and weights_k from probs_ik
        int dim = trainSamples.cols;

        // Update weights
        // not normalized first
        reduce(trainProbs, weights, 0, CV_REDUCE_SUM);

        // Update means
        means.create(nclusters, dim, CV_64FC1);
        means = Scalar(0);

        const double minPosWeight = trainSamples.rows * DBL_EPSILON;
        double minWeight = DBL_MAX;
        int minWeightClusterIndex = -1;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            // Degenerate cluster: handled by the copy loop near the end.
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;

            // Remember the lightest cluster that is still being updated.
            if(weights.at<double>(clusterIndex) < minWeight)
            {
                minWeight = weights.at<double>(clusterIndex);
                minWeightClusterIndex = clusterIndex;
            }

            // mean_k = sum_i(probs_ik * x_i) / sum_i(probs_ik);
            // clusterMean is a header onto row clusterIndex of means.
            Mat clusterMean = means.row(clusterIndex);
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
                clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
            clusterMean /= weights.at<double>(clusterIndex);
        }

        // Update covsEigenValues and invCovsEigenValues
        covs.resize(nclusters);
        covsEigenValues.resize(nclusters);
        if(covMatType == COV_MAT_GENERIC)
            covsRotateMats.resize(nclusters);
        invCovsEigenValues.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;

            if(covMatType != COV_MAT_SPHERICAL)
                covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
            else
                covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);

            if(covMatType == COV_MAT_GENERIC)
                covs[clusterIndex].create(dim, dim, CV_64FC1);

            // clusterCov shares its data with the matrix being accumulated:
            // the eigenvalue row for spherical/diagonal models, the full
            // covariance matrix for the generic model.
            Mat clusterCov = covMatType != COV_MAT_GENERIC ?
                covsEigenValues[clusterIndex] : covs[clusterIndex];

            clusterCov = Scalar(0);

            Mat centeredSample;
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
            {
                centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);

                if(covMatType == COV_MAT_GENERIC)
                    clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
                else
                {
                    // Diagonal: one variance per dimension.
                    // Spherical: everything accumulates into element 0.
                    double p = trainProbs.at<double>(sampleIndex, clusterIndex);
                    for(int di = 0; di < dim; di++ )
                    {
                        double val = centeredSample.at<double>(di);
                        clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
                    }
                }
            }

            // The spherical sum covers all dimensions; average over them.
            if(covMatType == COV_MAT_SPHERICAL)
                clusterCov /= dim;

            clusterCov /= weights.at<double>(clusterIndex);

            // Update covsRotateMats for COV_MAT_GENERIC only
            if(covMatType == COV_MAT_GENERIC)
            {
                SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
                covsEigenValues[clusterIndex] = svd.w;
                covsRotateMats[clusterIndex] = svd.u;
            }

            // Keep every eigenvalue at or above minEigenValue before inverting.
            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);

            // update invCovsEigenValues
            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
        }

        // Give every skipped cluster the parameters of the lightest updated one.
        // NOTE(review): minWeightClusterIndex is still -1 if no cluster passed
        // the weight threshold above - confirm that this cannot happen here.
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
            {
                Mat clusterMean = means.row(clusterIndex);
                means.row(minWeightClusterIndex).copyTo(clusterMean);
                covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
                covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
                if(covMatType == COV_MAT_GENERIC)
                    covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
                invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
            }
        }

        // Normalize weights
        weights /= trainSamples.rows;
    }
763 
write_params(FileStorage & fs) const764     void write_params(FileStorage& fs) const
765     {
766         fs << "nclusters" << nclusters;
767         fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
768                                  covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
769                                  covMatType == COV_MAT_GENERIC ? String("generic") :
770                                  format("unknown_%d", covMatType));
771         writeTermCrit(fs, termCrit);
772     }
773 
write(FileStorage & fs) const774     void write(FileStorage& fs) const CV_OVERRIDE
775     {
776         writeFormat(fs);
777         fs << "training_params" << "{";
778         write_params(fs);
779         fs << "}";
780         fs << "weights" << weights;
781         fs << "means" << means;
782 
783         size_t i, n = covs.size();
784 
785         fs << "covs" << "[";
786         for( i = 0; i < n; i++ )
787             fs << covs[i];
788         fs << "]";
789     }
790 
read_params(const FileNode & fn)791     void read_params(const FileNode& fn)
792     {
793         nclusters = (int)fn["nclusters"];
794         String s = (String)fn["cov_mat_type"];
795         covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
796                              s == "diagonal" ? COV_MAT_DIAGONAL :
797                              s == "generic" ? COV_MAT_GENERIC : -1;
798         CV_Assert(covMatType >= 0);
799         termCrit = readTermCrit(fn);
800     }
801 
read(const FileNode & fn)802     void read(const FileNode& fn) CV_OVERRIDE
803     {
804         clear();
805         read_params(fn["training_params"]);
806 
807         fn["weights"] >> weights;
808         fn["means"] >> means;
809 
810         FileNode cfn = fn["covs"];
811         FileNodeIterator cfn_it = cfn.begin();
812         int i, n = (int)cfn.size();
813         covs.resize(n);
814 
815         for( i = 0; i < n; i++, ++cfn_it )
816             (*cfn_it) >> covs[i];
817 
818         decomposeCovs();
819         computeLogWeightDivDet();
820     }
821 
getWeights() const822     Mat getWeights() const CV_OVERRIDE { return weights; }
getMeans() const823     Mat getMeans() const CV_OVERRIDE { return means; }
getCovs(std::vector<Mat> & _covs) const824     void getCovs(std::vector<Mat>& _covs) const CV_OVERRIDE
825     {
826         _covs.resize(covs.size());
827         std::copy(covs.begin(), covs.end(), _covs.begin());
828     }
829 
    // Inner matrices are CV_64FC1, with two exceptions visible in this class:
    // trainLabels is created as CV_32SC1 by eStep(), and trainSamples/means
    // may be CV_32FC1 between setTrainData() and clusterTrainSamples().

    // Training buffers; released at the end of a successful doTrain().
    Mat trainSamples;
    Mat trainProbs;
    Mat trainLogLikelihoods;
    Mat trainLabels;

    // Model parameters: 1 x nclusters weights, nclusters x dim means, and one
    // covariance matrix per cluster.
    Mat weights;
    Mat means;
    std::vector<Mat> covs;

    // Quantities derived from covs and weights, used by computeProbabilities().
    std::vector<Mat> covsEigenValues;
    std::vector<Mat> covsRotateMats;
    std::vector<Mat> invCovsEigenValues;
    Mat logWeightDivDet;
844 };
845 
create()846 Ptr<EM> EM::create()
847 {
848     return makePtr<EMImpl>();
849 }
850 
load(const String & filepath,const String & nodeName)851 Ptr<EM> EM::load(const String& filepath, const String& nodeName)
852 {
853     return Algorithm::load<EM>(filepath, nodeName);
854 }
855 
856 }
857 } // namespace cv
858 
859 /* End of file. */
860