1 // This file is part of OpenCV project.
2 // It is subject to the license terms in the LICENSE file found in the top-level directory
3 // of this distribution and at http://opencv.org/license.html.
4 #include "test_precomp.hpp"
5
6 namespace opencv_test {
7
defaultDistribs(Mat & means,vector<Mat> & covs,int type)8 void defaultDistribs( Mat& means, vector<Mat>& covs, int type)
9 {
10 float mp0[] = {0.0f, 0.0f}, cp0[] = {0.67f, 0.0f, 0.0f, 0.67f};
11 float mp1[] = {5.0f, 0.0f}, cp1[] = {1.0f, 0.0f, 0.0f, 1.0f};
12 float mp2[] = {1.0f, 5.0f}, cp2[] = {1.0f, 0.0f, 0.0f, 1.0f};
13 means.create(3, 2, type);
14 Mat m0( 1, 2, CV_32FC1, mp0 ), c0( 2, 2, CV_32FC1, cp0 );
15 Mat m1( 1, 2, CV_32FC1, mp1 ), c1( 2, 2, CV_32FC1, cp1 );
16 Mat m2( 1, 2, CV_32FC1, mp2 ), c2( 2, 2, CV_32FC1, cp2 );
17 means.resize(3), covs.resize(3);
18
19 Mat mr0 = means.row(0);
20 m0.convertTo(mr0, type);
21 c0.convertTo(covs[0], type);
22
23 Mat mr1 = means.row(1);
24 m1.convertTo(mr1, type);
25 c1.convertTo(covs[1], type);
26
27 Mat mr2 = means.row(2);
28 m2.convertTo(mr2, type);
29 c2.convertTo(covs[2], type);
30 }
31
32 // generate points sets by normal distributions
generateData(Mat & data,Mat & labels,const vector<int> & sizes,const Mat & _means,const vector<Mat> & covs,int dataType,int labelType)33 void generateData( Mat& data, Mat& labels, const vector<int>& sizes, const Mat& _means, const vector<Mat>& covs, int dataType, int labelType )
34 {
35 vector<int>::const_iterator sit = sizes.begin();
36 int total = 0;
37 for( ; sit != sizes.end(); ++sit )
38 total += *sit;
39 CV_Assert( _means.rows == (int)sizes.size() && covs.size() == sizes.size() );
40 CV_Assert( !data.empty() && data.rows == total );
41 CV_Assert( data.type() == dataType );
42
43 labels.create( data.rows, 1, labelType );
44
45 randn( data, Scalar::all(-1.0), Scalar::all(1.0) );
46 vector<Mat> means(sizes.size());
47 for(int i = 0; i < _means.rows; i++)
48 means[i] = _means.row(i);
49 vector<Mat>::const_iterator mit = means.begin(), cit = covs.begin();
50 int bi, ei = 0;
51 sit = sizes.begin();
52 for( int p = 0, l = 0; sit != sizes.end(); ++sit, ++mit, ++cit, l++ )
53 {
54 bi = ei;
55 ei = bi + *sit;
56 CV_Assert( mit->rows == 1 && mit->cols == data.cols );
57 CV_Assert( cit->rows == data.cols && cit->cols == data.cols );
58 for( int i = bi; i < ei; i++, p++ )
59 {
60 Mat r = data.row(i);
61 r = r * (*cit) + *mit;
62 if( labelType == CV_32FC1 )
63 labels.at<float>(p, 0) = (float)l;
64 else if( labelType == CV_32SC1 )
65 labels.at<int>(p, 0) = l;
66 else
67 {
68 CV_DbgAssert(0);
69 }
70 }
71 }
72 }
73
maxIdx(const vector<int> & count)74 int maxIdx( const vector<int>& count )
75 {
76 int idx = -1;
77 int maxVal = -1;
78 vector<int>::const_iterator it = count.begin();
79 for( int i = 0; it != count.end(); ++it, i++ )
80 {
81 if( *it > maxVal)
82 {
83 maxVal = *it;
84 idx = i;
85 }
86 }
87 CV_Assert( idx >= 0);
88 return idx;
89 }
90
getLabelsMap(const Mat & labels,const vector<int> & sizes,vector<int> & labelsMap,bool checkClusterUniq)91 bool getLabelsMap( const Mat& labels, const vector<int>& sizes, vector<int>& labelsMap, bool checkClusterUniq)
92 {
93 size_t total = 0, nclusters = sizes.size();
94 for(size_t i = 0; i < sizes.size(); i++)
95 total += sizes[i];
96
97 CV_Assert( !labels.empty() );
98 CV_Assert( labels.total() == total && (labels.cols == 1 || labels.rows == 1));
99 CV_Assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
100
101 bool isFlt = labels.type() == CV_32FC1;
102
103 labelsMap.resize(nclusters);
104
105 vector<bool> buzy(nclusters, false);
106 int startIndex = 0;
107 for( size_t clusterIndex = 0; clusterIndex < sizes.size(); clusterIndex++ )
108 {
109 vector<int> count( nclusters, 0 );
110 for( int i = startIndex; i < startIndex + sizes[clusterIndex]; i++)
111 {
112 int lbl = isFlt ? (int)labels.at<float>(i) : labels.at<int>(i);
113 CV_Assert(lbl < (int)nclusters);
114 count[lbl]++;
115 CV_Assert(count[lbl] < (int)total);
116 }
117 startIndex += sizes[clusterIndex];
118
119 int cls = maxIdx( count );
120 CV_Assert( !checkClusterUniq || !buzy[cls] );
121
122 labelsMap[clusterIndex] = cls;
123
124 buzy[cls] = true;
125 }
126
127 if(checkClusterUniq)
128 {
129 for(size_t i = 0; i < buzy.size(); i++)
130 if(!buzy[i])
131 return false;
132 }
133
134 return true;
135 }
136
calcErr(const Mat & labels,const Mat & origLabels,const vector<int> & sizes,float & err,bool labelsEquivalent,bool checkClusterUniq)137 bool calcErr( const Mat& labels, const Mat& origLabels, const vector<int>& sizes, float& err, bool labelsEquivalent, bool checkClusterUniq)
138 {
139 err = 0;
140 CV_Assert( !labels.empty() && !origLabels.empty() );
141 CV_Assert( labels.rows == 1 || labels.cols == 1 );
142 CV_Assert( origLabels.rows == 1 || origLabels.cols == 1 );
143 CV_Assert( labels.total() == origLabels.total() );
144 CV_Assert( labels.type() == CV_32SC1 || labels.type() == CV_32FC1 );
145 CV_Assert( origLabels.type() == labels.type() );
146
147 vector<int> labelsMap;
148 bool isFlt = labels.type() == CV_32FC1;
149 if( !labelsEquivalent )
150 {
151 if( !getLabelsMap( labels, sizes, labelsMap, checkClusterUniq ) )
152 return false;
153
154 for( int i = 0; i < labels.rows; i++ )
155 if( isFlt )
156 err += labels.at<float>(i) != labelsMap[(int)origLabels.at<float>(i)] ? 1.f : 0.f;
157 else
158 err += labels.at<int>(i) != labelsMap[origLabels.at<int>(i)] ? 1.f : 0.f;
159 }
160 else
161 {
162 for( int i = 0; i < labels.rows; i++ )
163 if( isFlt )
164 err += labels.at<float>(i) != origLabels.at<float>(i) ? 1.f : 0.f;
165 else
166 err += labels.at<int>(i) != origLabels.at<int>(i) ? 1.f : 0.f;
167 }
168 err /= (float)labels.rows;
169 return true;
170 }
171
calculateError(const Mat & _p_labels,const Mat & _o_labels,float & error)172 bool calculateError( const Mat& _p_labels, const Mat& _o_labels, float& error)
173 {
174 error = 0.0f;
175 float accuracy = 0.0f;
176 Mat _p_labels_temp;
177 Mat _o_labels_temp;
178 _p_labels.convertTo(_p_labels_temp, CV_32S);
179 _o_labels.convertTo(_o_labels_temp, CV_32S);
180
181 CV_Assert(_p_labels_temp.total() == _o_labels_temp.total());
182 CV_Assert(_p_labels_temp.rows == _o_labels_temp.rows);
183
184 accuracy = (float)countNonZero(_p_labels_temp == _o_labels_temp)/_p_labels_temp.rows;
185 error = 1 - accuracy;
186 return true;
187 }
188
189 } // namespace
190