1 /*=========================================================================
2 *
3 * Copyright Insight Software Consortium
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
19 // Insight classes
20
21 #include "itkImageKmeansModelEstimator.h"
22 #include "itkDistanceToCentroidMembershipFunction.h"
23
// Data definitions.
// Test image: IMGWIDTH x IMGHEIGHT x NFRAMES pixels, each pixel a vector of
// NUMBANDS components, stored in an NDIMENSION-dimensional image.
#define IMGWIDTH 16
#define IMGHEIGHT 1
#define NFRAMES 1
#define NUMBANDS 2
#define NDIMENSION 3

// Reference codebook: CDBKWIDTH x CDBKHEIGHT x NFRAMES codewords of NUMBANDS
// components each. NFRAMES, NUMBANDS and NDIMENSION are shared with the image
// definitions above instead of being defined a second time.
#define CDBKWIDTH 4
#define CDBKHEIGHT 1
// Parenthesized so the product stays intact when the macro is expanded inside
// a larger expression (e.g. next to '/' or '%').
#define NCODEWORDS ( CDBKWIDTH * CDBKHEIGHT * NFRAMES )
#define STARTFRAME 0
#define NUM_BYTES_PER_PIXEL 1
#define ONEBAND 1
40
41
42 // class to support progress feeback
43 class ShowProgressObject
44 {
45 public:
ShowProgressObject(itk::LightProcessObject * o)46 ShowProgressObject(itk::LightProcessObject * o)
47 {m_Process = o;}
ShowProgress()48 void ShowProgress()
49 {std::cout << "Progress " << m_Process->GetProgress() << std::endl;}
50 itk::LightProcessObject::Pointer m_Process;
51 };
52
itkKmeansModelEstimatorTest(int,char * [])53 int itkKmeansModelEstimatorTest(int, char* [] )
54 {
55 //------------------------------------------------------
56 //Create a simple test vector with 16 entries and 2 bands
57 //------------------------------------------------------
58 using VecImageType = itk::Image<itk::Vector<double,NUMBANDS>,NDIMENSION>;
59
60 using VecImagePixelType = VecImageType::PixelType;
61
62 VecImageType::Pointer vecImage = VecImageType::New();
63
64 VecImageType::SizeType vecImgSize = {{ IMGWIDTH , IMGHEIGHT, NFRAMES }};
65
66 VecImageType::IndexType index;
67 index.Fill(0);
68 VecImageType::RegionType region;
69
70 region.SetSize( vecImgSize );
71 region.SetIndex( index );
72
73 vecImage->SetLargestPossibleRegion( region );
74 vecImage->SetBufferedRegion( region );
75 vecImage->Allocate();
76
77 // setup the iterators
78 enum { VecImageDimension = VecImageType::ImageDimension };
79 using VecIterator = itk::ImageRegionIterator<VecImageType>;
80
81 VecIterator outIt( vecImage, vecImage->GetBufferedRegion() );
82
83 //--------------------------------------------------------------------------
84 //Manually create and store each vector
85 //--------------------------------------------------------------------------
86
87 //Vector no. 1
88 VecImagePixelType vec;
89 vec[0] = 21; vec[1] = 9; outIt.Set( vec ); ++outIt;
90 //Vector no. 2
91 vec[0] = 10; vec[1] = 20; outIt.Set( vec ); ++outIt;
92 //Vector no. 3
93 vec[0] = 8; vec[1] = 21; outIt.Set( vec ); ++outIt;
94 //Vector no. 4
95 vec[0] = 10; vec[1] = 23; outIt.Set( vec ); ++outIt;
96 //Vector no. 5
97 vec[0] = 12; vec[1] = 21; outIt.Set( vec ); ++outIt;
98 //Vector no. 6
99 vec[0] = 11; vec[1] = 12; outIt.Set( vec ); ++outIt;
100 //Vector no. 7
101 vec[0] = 15; vec[1] = 22; outIt.Set( vec ); ++outIt;
102 //Vector no. 8
103 vec[0] = 9; vec[1] = 10; outIt.Set( vec ); ++outIt;
104 //Vector no. 9
105 vec[0] = 19; vec[1] = 10; outIt.Set( vec ); ++outIt;
106 //Vector no. 10
107 vec[0] = 19; vec[1] = 10; outIt.Set( vec ); ++outIt;
108 //Vector no. 11
109 vec[0] = 21; vec[1] = 21; outIt.Set( vec ); ++outIt;
110 //Vector no. 12
111 vec[0] = 11; vec[1] = 20; outIt.Set( vec ); ++outIt;
112 //Vector no. 13
113 vec[0] = 8; vec[1] = 18; outIt.Set( vec ); ++outIt;
114 //Vector no. 14
115 vec[0] = 18; vec[1] = 10; outIt.Set( vec ); ++outIt;
116 //Vector no. 15
117 vec[0] = 22; vec[1] = 10; outIt.Set( vec ); ++outIt;
118 //Vector no. 16
119 vec[0] = 24; vec[1] = 23; outIt.Set( vec ); ++outIt;
120
121 outIt.GoToBegin();
122
123 //---------------------------------------------------------------
124 //Input the codebook
125 //---------------------------------------------------------------
126 //------------------------------------------------------------------
127 //Read the codebook into an vnl_matrix
128 //------------------------------------------------------------------
129
130 vnl_matrix<double> inCDBK(NCODEWORDS, NUMBANDS);
131 //There are 4 entries to the code book
132 int r,c;
133 r=0; c=0; inCDBK.put(r,c,10);
134 r=0; c=1; inCDBK.put(r,c,10);
135 r=1; c=0; inCDBK.put(r,c,10);
136 r=1; c=1; inCDBK.put(r,c,20);
137 r=2; c=0; inCDBK.put(r,c,20);
138 r=2; c=1; inCDBK.put(r,c,10);
139 r=3; c=0; inCDBK.put(r,c,20);
140 r=3; c=1; inCDBK.put(r,c,20);
141
142 //----------------------------------------------------------------------
143 // Test code for the Kmeans model estimator
144 //----------------------------------------------------------------------
145
146 //---------------------------------------------------------------------
147 // Multiband data is now available in the right format
148 //---------------------------------------------------------------------
149
150 //----------------------------------------------------------------------
151 //Set membership function (Using the statistics objects)
152 //----------------------------------------------------------------------
153 namespace stat = itk::Statistics;
154
155 using MembershipFunctionType = stat::DistanceToCentroidMembershipFunction<VecImagePixelType>;
156 using MembershipFunctionPointer = MembershipFunctionType::Pointer;
157
158 using MembershipFunctionPointerVector = std::vector<MembershipFunctionPointer>;
159
160
161 //----------------------------------------------------------------------
162 //Set the image model estimator
163 //----------------------------------------------------------------------
164 using ImageKmeansModelEstimatorType =
165 itk::ImageKmeansModelEstimator<VecImageType, MembershipFunctionType>;
166
167 ImageKmeansModelEstimatorType::Pointer
168 applyKmeansEstimator = ImageKmeansModelEstimatorType::New();
169
170 //----------------------------------------------------------------------
171 //Set the parameters of the clusterer
172 //----------------------------------------------------------------------
173 applyKmeansEstimator->SetInputImage(vecImage);
174 applyKmeansEstimator->SetNumberOfModels(NCODEWORDS);
175 applyKmeansEstimator->SetThreshold(0.01 );
176 applyKmeansEstimator->SetOffsetAdd( 0.01 );
177 applyKmeansEstimator->SetOffsetMultiply( 0.01 );
178 applyKmeansEstimator->SetMaxSplitAttempts( 10 );
179 applyKmeansEstimator->Update();
180 applyKmeansEstimator->Print(std::cout);
181
182 MembershipFunctionPointerVector membershipFunctions =
183 applyKmeansEstimator->GetMembershipFunctions();
184
185 vnl_vector<double> kmeansResultForClass;
186 vnl_vector<double> referenceCodebookForClass;
187 vnl_vector<double> errorForClass;
188 double error =0;
189 double meanCDBKvalue = 0;
190
191 for(unsigned int classIndex=0; classIndex < membershipFunctions.size();
192 classIndex++ )
193 {
194 kmeansResultForClass = membershipFunctions[classIndex]->GetCentroid();
195 referenceCodebookForClass = inCDBK.get_row( classIndex);
196 errorForClass = kmeansResultForClass - referenceCodebookForClass;
197
198 for(int i = 0; i < NUMBANDS; i++)
199 {
200 error += itk::Math::abs(errorForClass[i]/referenceCodebookForClass[i]);
201 meanCDBKvalue += referenceCodebookForClass[i];
202 }
203
204 }
205 error /= NCODEWORDS*NUMBANDS;
206 meanCDBKvalue /= NCODEWORDS*NUMBANDS;
207
208 if( error < 0.1 * meanCDBKvalue)
209 std::cout << "Kmeans algorithm passed (without initial input)"<<std::endl;
210 else
211 std::cout << "Kmeans algorithm failed (without initial input)"<<std::endl;
212
213 //Validation with no codebook/initial Kmeans estimate
214 vnl_matrix<double> kmeansResult = applyKmeansEstimator->GetKmeansResults();
215 std::cout << "KMeansResults\n" << kmeansResult << std::endl;
216
217 applyKmeansEstimator->SetCodebook(inCDBK);
218 applyKmeansEstimator->Update();
219 applyKmeansEstimator->Print(std::cout);
220
221 membershipFunctions = applyKmeansEstimator->GetMembershipFunctions();
222
223 //Testing for the various parameter access functions in the test
224 std::cout << "The final codebook (cluster centers are: " << std::endl;
225 std::cout << applyKmeansEstimator->GetCodebook() << std::endl;
226 std::cout << "The threshold parameter used was: " <<
227 applyKmeansEstimator->GetThreshold() << std::endl;
228 std::cout << "The additive ofset parameter used was: " <<
229 applyKmeansEstimator->GetOffsetAdd() << std::endl;
230 std::cout << "The multiplicative ofset parameter used was: " <<
231 applyKmeansEstimator->GetOffsetMultiply() << std::endl;
232 std::cout << "The maximum number of attempted splits in codebook: " <<
233 applyKmeansEstimator->GetMaxSplitAttempts() << std::endl;
234 std::cout << " " << std::endl;
235
236 //Testing the distance of the first pixel to the centroids; identify the class
237 //closest to the fist pixel.
238 unsigned int minidx = 0;
239 double mindist = 99999999;
240 double classdist;
241 for( unsigned int idx=0; idx < membershipFunctions.size(); idx++ )
242 {
243 classdist = membershipFunctions[idx]->Evaluate( outIt.Get() );
244 std::cout << "Distance of first pixel to class " << idx << " is: " << classdist << std::endl;
245 if( mindist > classdist )
246 {
247 mindist = classdist;
248 minidx = idx;
249 }
250 }
251
252 //Validation with initial Kmeans estimate provided as input by the user
253 error =0;
254 meanCDBKvalue = 0;
255 const size_t test = membershipFunctions.size();
256 for(unsigned int classIndex=0; classIndex < test; classIndex++ )
257 {
258 kmeansResultForClass = membershipFunctions[classIndex]->GetCentroid();
259 referenceCodebookForClass = inCDBK.get_row( classIndex);
260 errorForClass = kmeansResultForClass - referenceCodebookForClass;
261
262 for(int i = 0; i < NUMBANDS; i++)
263 {
264 error += itk::Math::abs(errorForClass[i]/referenceCodebookForClass[i]);
265 meanCDBKvalue += referenceCodebookForClass[i];
266 }
267 }
268
269 error /= NCODEWORDS*NUMBANDS;
270 meanCDBKvalue /= NCODEWORDS*NUMBANDS;
271
272 //Check if the mean codebook is within error limits and the first pixel
273 //is labeled to belong to class 2
274 if( (error < 0.1 * meanCDBKvalue) && (minidx == 2) )
275 {
276 std::cout << "Kmeans algorithm passed (with initial input)"<<std::endl;
277 }
278 else
279 {
280 std::cout << "Kmeans algorithm failed (with initial input)"<<std::endl;
281 return EXIT_FAILURE;
282 }
283
284 return EXIT_SUCCESS;
285 }
286