1 #include "vtkKMeansStatistics.h"
2 #include "vtkKMeansAssessFunctor.h"
3 #include "vtkKMeansDistanceFunctor.h"
4 #include "vtkStringArray.h"
5
6 #include "vtkDataObject.h"
7 #include "vtkDoubleArray.h"
8 #include "vtkIdTypeArray.h"
9 #include "vtkInformation.h"
10 #include "vtkIntArray.h"
11 #include "vtkMultiBlockDataSet.h"
12 #include "vtkObjectFactory.h"
13 #include "vtkStatisticsAlgorithmPrivate.h"
14 #include "vtkTable.h"
15 #include "vtkVariantArray.h"
16
17 #include <map>
18 #include <sstream>
19 #include <vector>
20
// Generates vtkKMeansStatistics::New().
vtkStandardNewMacro(vtkKMeansStatistics);
// Generates SetDistanceFunctor(); the destructor calls SetDistanceFunctor(nullptr) to release
// the functor created in the constructor.
vtkCxxSetObjectMacro(vtkKMeansStatistics, DistanceFunctor, vtkKMeansDistanceFunctor);
23
24 //------------------------------------------------------------------------------
vtkKMeansStatistics()25 vtkKMeansStatistics::vtkKMeansStatistics()
26 {
27 this->AssessNames->SetNumberOfValues(2);
28 this->AssessNames->SetValue(0, "Distance");
29 this->AssessNames->SetValue(1, "ClosestId");
30 this->DefaultNumberOfClusters = 3;
31 this->Tolerance = 0.01;
32 this->KValuesArrayName = nullptr;
33 this->SetKValuesArrayName("K");
34 this->MaxNumIterations = 50;
35 this->DistanceFunctor = vtkKMeansDistanceFunctor::New();
36 }
37
38 //------------------------------------------------------------------------------
~vtkKMeansStatistics()39 vtkKMeansStatistics::~vtkKMeansStatistics()
40 {
41 this->SetKValuesArrayName(nullptr);
42 this->SetDistanceFunctor(nullptr);
43 }
44
45 //------------------------------------------------------------------------------
PrintSelf(ostream & os,vtkIndent indent)46 void vtkKMeansStatistics::PrintSelf(ostream& os, vtkIndent indent)
47 {
48 this->Superclass::PrintSelf(os, indent);
49 os << indent << "DefaultNumberofClusters: " << this->DefaultNumberOfClusters << endl;
50 os << indent << "KValuesArrayName: \""
51 << (this->KValuesArrayName ? this->KValuesArrayName : "nullptr") << "\"\n";
52 os << indent << "MaxNumIterations: " << this->MaxNumIterations << endl;
53 os << indent << "Tolerance: " << this->Tolerance << endl;
54 os << indent << "DistanceFunctor: " << this->DistanceFunctor << endl;
55 }
56
57 //------------------------------------------------------------------------------
InitializeDataAndClusterCenters(vtkTable * inParameters,vtkTable * inData,vtkTable * dataElements,vtkIdTypeArray * numberOfClusters,vtkTable * curClusterElements,vtkTable * newClusterElements,vtkIdTypeArray * startRunID,vtkIdTypeArray * endRunID)58 int vtkKMeansStatistics::InitializeDataAndClusterCenters(vtkTable* inParameters, vtkTable* inData,
59 vtkTable* dataElements, vtkIdTypeArray* numberOfClusters, vtkTable* curClusterElements,
60 vtkTable* newClusterElements, vtkIdTypeArray* startRunID, vtkIdTypeArray* endRunID)
61 {
62 std::set<std::set<vtkStdString>>::const_iterator reqIt;
63 if (this->Internals->Requests.size() > 1)
64 {
65 static int num = 0;
66 num++;
67 if (num < 10)
68 {
69 vtkWarningMacro("Only the first request will be processed -- the rest will be ignored.");
70 }
71 }
72
73 if (this->Internals->Requests.empty())
74 {
75 vtkErrorMacro("No requests were made.");
76 return 0;
77 }
78 reqIt = this->Internals->Requests.begin();
79
80 vtkIdType numToAllocate;
81 vtkIdType numRuns = 0;
82
83 int initialClusterCentersProvided = 0;
84
85 // process parameter input table
86 if (inParameters && inParameters->GetNumberOfRows() > 0 && inParameters->GetNumberOfColumns() > 1)
87 {
88 vtkIdTypeArray* counts = vtkArrayDownCast<vtkIdTypeArray>(inParameters->GetColumn(0));
89 if (!counts)
90 {
91 vtkWarningMacro("The first column of the input parameter table should be of vtkIdType."
92 << endl
93 << "The input table provided will be ignored and a single run will be performed using the "
94 "first "
95 << this->DefaultNumberOfClusters << " observations as the initial cluster centers.");
96 }
97 else
98 {
99 initialClusterCentersProvided = 1;
100 numToAllocate = inParameters->GetNumberOfRows();
101 numberOfClusters->SetNumberOfValues(numToAllocate);
102 numberOfClusters->SetName(inParameters->GetColumn(0)->GetName());
103
104 for (vtkIdType i = 0; i < numToAllocate; ++i)
105 {
106 numberOfClusters->SetValue(i, counts->GetValue(i));
107 }
108 vtkIdType curRow = 0;
109 while (curRow < inParameters->GetNumberOfRows())
110 {
111 numRuns++;
112 startRunID->InsertNextValue(curRow);
113 curRow += inParameters->GetValue(curRow, 0).ToInt();
114 endRunID->InsertNextValue(curRow);
115 }
116 vtkTable* condensedTable = vtkTable::New();
117 std::set<vtkStdString>::const_iterator colItr;
118 for (colItr = reqIt->begin(); colItr != reqIt->end(); ++colItr)
119 {
120 vtkAbstractArray* pArr = inParameters->GetColumnByName(colItr->c_str());
121 vtkAbstractArray* dArr = inData->GetColumnByName(colItr->c_str());
122 if (pArr && dArr)
123 {
124 condensedTable->AddColumn(pArr);
125 dataElements->AddColumn(dArr);
126 }
127 else
128 {
129 vtkWarningMacro("Skipping requested column \"" << colItr->c_str() << "\".");
130 }
131 }
132 newClusterElements->DeepCopy(condensedTable);
133 curClusterElements->DeepCopy(condensedTable);
134 condensedTable->Delete();
135 }
136 }
137 if (!initialClusterCentersProvided)
138 {
139 // otherwise create an initial set of cluster coords
140 numRuns = 1;
141 numToAllocate = this->DefaultNumberOfClusters < inData->GetNumberOfRows()
142 ? this->DefaultNumberOfClusters
143 : inData->GetNumberOfRows();
144 startRunID->InsertNextValue(0);
145 endRunID->InsertNextValue(numToAllocate);
146 numberOfClusters->SetName(this->KValuesArrayName);
147
148 for (vtkIdType j = 0; j < inData->GetNumberOfColumns(); j++)
149 {
150 if (reqIt->find(inData->GetColumnName(j)) != reqIt->end())
151 {
152 vtkAbstractArray* curCoords = this->DistanceFunctor->CreateCoordinateArray();
153 vtkAbstractArray* newCoords = this->DistanceFunctor->CreateCoordinateArray();
154 curCoords->SetName(inData->GetColumnName(j));
155 newCoords->SetName(inData->GetColumnName(j));
156 curClusterElements->AddColumn(curCoords);
157 newClusterElements->AddColumn(newCoords);
158 curCoords->Delete();
159 newCoords->Delete();
160 dataElements->AddColumn(inData->GetColumnByName(inData->GetColumnName(j)));
161 }
162 }
163 CreateInitialClusterCenters(
164 numToAllocate, numberOfClusters, inData, curClusterElements, newClusterElements);
165 }
166
167 if (curClusterElements->GetNumberOfColumns() == 0)
168 {
169 return 0;
170 }
171 return numRuns;
172 }
173
174 //------------------------------------------------------------------------------
CreateInitialClusterCenters(vtkIdType numToAllocate,vtkIdTypeArray * numberOfClusters,vtkTable * inData,vtkTable * curClusterElements,vtkTable * newClusterElements)175 void vtkKMeansStatistics::CreateInitialClusterCenters(vtkIdType numToAllocate,
176 vtkIdTypeArray* numberOfClusters, vtkTable* inData, vtkTable* curClusterElements,
177 vtkTable* newClusterElements)
178 {
179 std::set<std::set<vtkStdString>>::const_iterator reqIt;
180 if (this->Internals->Requests.size() > 1)
181 {
182 static int num = 0;
183 ++num;
184 if (num < 10)
185 {
186 vtkWarningMacro("Only the first request will be processed -- the rest will be ignored.");
187 }
188 }
189
190 if (this->Internals->Requests.empty())
191 {
192 vtkErrorMacro("No requests were made.");
193 return;
194 }
195 reqIt = this->Internals->Requests.begin();
196
197 for (vtkIdType i = 0; i < numToAllocate; ++i)
198 {
199 numberOfClusters->InsertNextValue(numToAllocate);
200 vtkVariantArray* curRow = vtkVariantArray::New();
201 vtkVariantArray* newRow = vtkVariantArray::New();
202 for (int j = 0; j < inData->GetNumberOfColumns(); j++)
203 {
204 if (reqIt->find(inData->GetColumnName(j)) != reqIt->end())
205 {
206 curRow->InsertNextValue(inData->GetValue(i, j));
207 newRow->InsertNextValue(inData->GetValue(i, j));
208 }
209 }
210 curClusterElements->InsertNextRow(curRow);
211 newClusterElements->InsertNextRow(newRow);
212 curRow->Delete();
213 newRow->Delete();
214 }
215 }
216
217 //------------------------------------------------------------------------------
GetTotalNumberOfObservations(vtkIdType numObservations)218 vtkIdType vtkKMeansStatistics::GetTotalNumberOfObservations(vtkIdType numObservations)
219 {
220 return numObservations;
221 }
222
223 //------------------------------------------------------------------------------
UpdateClusterCenters(vtkTable * newClusterElements,vtkTable * curClusterElements,vtkIdTypeArray * vtkNotUsed (numMembershipChanges),vtkIdTypeArray * numDataElementsInCluster,vtkDoubleArray * vtkNotUsed (error),vtkIdTypeArray * startRunID,vtkIdTypeArray * endRunID,vtkIntArray * computeRun)224 void vtkKMeansStatistics::UpdateClusterCenters(vtkTable* newClusterElements,
225 vtkTable* curClusterElements, vtkIdTypeArray* vtkNotUsed(numMembershipChanges),
226 vtkIdTypeArray* numDataElementsInCluster, vtkDoubleArray* vtkNotUsed(error),
227 vtkIdTypeArray* startRunID, vtkIdTypeArray* endRunID, vtkIntArray* computeRun)
228 {
229 for (vtkIdType runID = 0; runID < startRunID->GetNumberOfTuples(); runID++)
230 {
231 if (computeRun->GetValue(runID))
232 {
233 for (vtkIdType i = startRunID->GetValue(runID); i < endRunID->GetValue(runID); ++i)
234 {
235 if (numDataElementsInCluster->GetValue(i) == 0)
236 {
237 vtkWarningMacro("cluster center " << i - startRunID->GetValue(runID) << " in run "
238 << runID << " is degenerate. Attempting to perturb");
239 this->DistanceFunctor->PerturbElement(newClusterElements, curClusterElements, i,
240 startRunID->GetValue(runID), endRunID->GetValue(runID), 0.8);
241 }
242 }
243 }
244 }
245 }
246
247 //------------------------------------------------------------------------------
SetParameter(const char * parameter,int vtkNotUsed (index),vtkVariant value)248 bool vtkKMeansStatistics::SetParameter(
249 const char* parameter, int vtkNotUsed(index), vtkVariant value)
250 {
251 if (!parameter)
252 return false;
253
254 vtkStdString pname = parameter;
255 if (pname == "DefaultNumberOfClusters" || pname == "k" || pname == "K")
256 {
257 bool valid;
258 int k = value.ToInt(&valid);
259 if (valid && k > 0)
260 {
261 this->SetDefaultNumberOfClusters(k);
262 return true;
263 }
264 }
265 else if (pname == "Tolerance")
266 {
267 double tol = value.ToDouble();
268 this->SetTolerance(tol);
269 return true;
270 }
271 else if (pname == "MaxNumIterations")
272 {
273 bool valid;
274 int maxit = value.ToInt(&valid);
275 if (valid && maxit >= 0)
276 {
277 this->SetMaxNumIterations(maxit);
278 return true;
279 }
280 }
281
282 return false;
283 }
284
285 //------------------------------------------------------------------------------
Learn(vtkTable * inData,vtkTable * inParameters,vtkMultiBlockDataSet * outMeta)286 void vtkKMeansStatistics::Learn(
287 vtkTable* inData, vtkTable* inParameters, vtkMultiBlockDataSet* outMeta)
288 {
289 if (!outMeta)
290 {
291 return;
292 }
293
294 if (!inData)
295 {
296 return;
297 }
298
299 if (!this->DistanceFunctor)
300 {
301 vtkErrorMacro("Distance functor is nullptr");
302 return;
303 }
304
305 // Data initialization
306 vtkIdTypeArray* numberOfClusters = vtkIdTypeArray::New();
307 vtkTable* curClusterElements = vtkTable::New();
308 vtkTable* newClusterElements = vtkTable::New();
309 vtkIdTypeArray* startRunID = vtkIdTypeArray::New();
310 vtkIdTypeArray* endRunID = vtkIdTypeArray::New();
311 vtkTable* dataElements = vtkTable::New();
312 int numRuns = InitializeDataAndClusterCenters(inParameters, inData, dataElements,
313 numberOfClusters, curClusterElements, newClusterElements, startRunID, endRunID);
314 if (numRuns == 0)
315 {
316 numberOfClusters->Delete();
317 curClusterElements->Delete();
318 newClusterElements->Delete();
319 startRunID->Delete();
320 endRunID->Delete();
321 dataElements->Delete();
322 return;
323 }
324
325 vtkIdType numObservations = inData->GetNumberOfRows();
326 vtkIdType totalNumberOfObservations = this->GetTotalNumberOfObservations(numObservations);
327 vtkIdType numToAllocate = curClusterElements->GetNumberOfRows();
328 vtkIdTypeArray* numIterations = vtkIdTypeArray::New();
329 vtkIdTypeArray* numDataElementsInCluster = vtkIdTypeArray::New();
330 vtkDoubleArray* error = vtkDoubleArray::New();
331 vtkIdTypeArray* clusterMemberID = vtkIdTypeArray::New();
332 vtkIdTypeArray* numMembershipChanges = vtkIdTypeArray::New();
333 vtkIntArray* computeRun = vtkIntArray::New();
334 vtkIdTypeArray* clusterRunIDs = vtkIdTypeArray::New();
335
336 numDataElementsInCluster->SetNumberOfValues(numToAllocate);
337 numDataElementsInCluster->SetName("Cardinality");
338 clusterRunIDs->SetNumberOfValues(numToAllocate);
339 clusterRunIDs->SetName("Run ID");
340 error->SetNumberOfValues(numToAllocate);
341 error->SetName("Error");
342 numIterations->SetNumberOfValues(numToAllocate);
343 numIterations->SetName("Iterations");
344 numMembershipChanges->SetNumberOfValues(numRuns);
345 computeRun->SetNumberOfValues(numRuns);
346 clusterMemberID->SetNumberOfValues(numObservations * numRuns);
347 clusterMemberID->SetName("cluster member id");
348
349 for (int i = 0; i < numRuns; ++i)
350 {
351 for (vtkIdType j = startRunID->GetValue(i); j < endRunID->GetValue(i); j++)
352 {
353 clusterRunIDs->SetValue(j, i);
354 }
355 }
356
357 numIterations->FillComponent(0, 0);
358 computeRun->FillComponent(0, 1);
359 int allConverged, numIter = 0;
360 clusterMemberID->FillComponent(0, -1);
361
362 // Iterate until new cluster centers have converged OR we have reached a max number of iterations
363 do
364 {
365 // Initialize coordinates, cluster sizes and errors
366 numMembershipChanges->FillComponent(0, 0);
367 for (int runID = 0; runID < numRuns; runID++)
368 {
369 if (computeRun->GetValue(runID))
370 {
371 for (vtkIdType j = startRunID->GetValue(runID); j < endRunID->GetValue(runID); j++)
372 {
373 curClusterElements->SetRow(j, newClusterElements->GetRow(j));
374 newClusterElements->SetRow(
375 j, this->DistanceFunctor->GetEmptyTuple(newClusterElements->GetNumberOfColumns()));
376 numDataElementsInCluster->SetValue(j, 0);
377 error->SetValue(j, 0.0);
378 }
379 }
380 }
381
382 // Find minimum distance between each observation and each cluster center,
383 // then assign the observation to the nearest cluster.
384 vtkIdType localMemberID, offsetLocalMemberID;
385 double minDistance, curDistance;
386 for (vtkIdType observation = 0; observation < dataElements->GetNumberOfRows(); observation++)
387 {
388 for (int runID = 0; runID < numRuns; runID++)
389 {
390 if (computeRun->GetValue(runID))
391 {
392 vtkIdType runStartIdx = startRunID->GetValue(runID);
393 vtkIdType runEndIdx = endRunID->GetValue(runID);
394 if (runStartIdx >= runEndIdx)
395 {
396 continue;
397 }
398 vtkIdType j = runStartIdx;
399 localMemberID = 0;
400 offsetLocalMemberID = runStartIdx;
401 (*this->DistanceFunctor)(
402 minDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
403 curDistance = minDistance;
404 ++j;
405 for (/* no init */; j < runEndIdx; j++)
406 {
407 (*this->DistanceFunctor)(
408 curDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
409 if (curDistance < minDistance)
410 {
411 minDistance = curDistance;
412 localMemberID = j - runStartIdx;
413 offsetLocalMemberID = j;
414 }
415 }
416 // We've located the nearest cluster center. Has it changed since the last iteration?
417 if (clusterMemberID->GetValue(observation * numRuns + runID) != localMemberID)
418 {
419 numMembershipChanges->SetValue(runID, numMembershipChanges->GetValue(runID) + 1);
420 clusterMemberID->SetValue(observation * numRuns + runID, localMemberID);
421 }
422 // Give the distance functor a chance to modify any derived quantities used to
423 // change the cluster centers between iterations, now that we know which cluster
424 // center the observation is assigned to.
425 vtkIdType newCardinality = numDataElementsInCluster->GetValue(offsetLocalMemberID) + 1;
426 numDataElementsInCluster->SetValue(offsetLocalMemberID, newCardinality);
427 this->DistanceFunctor->PairwiseUpdate(newClusterElements, offsetLocalMemberID,
428 dataElements->GetRow(observation), 1, newCardinality);
429 // Update the error for this cluster center to account for this observation.
430 error->SetValue(offsetLocalMemberID, error->GetValue(offsetLocalMemberID) + minDistance);
431 }
432 }
433 }
434 // update cluster centers
435 this->UpdateClusterCenters(newClusterElements, curClusterElements, numMembershipChanges,
436 numDataElementsInCluster, error, startRunID, endRunID, computeRun);
437
438 // check for convergence
439 numIter++;
440 allConverged = 0;
441
442 for (int j = 0; j < numRuns; j++)
443 {
444 if (computeRun->GetValue(j))
445 {
446 double percentChanged = static_cast<double>(numMembershipChanges->GetValue(j)) /
447 static_cast<double>(totalNumberOfObservations);
448 if (percentChanged < this->Tolerance || numIter == this->MaxNumIterations)
449 {
450 allConverged++;
451 computeRun->SetValue(j, 0);
452 for (int k = startRunID->GetValue(j); k < endRunID->GetValue(j); k++)
453 {
454 numIterations->SetValue(k, numIter);
455 }
456 }
457 }
458 else
459 {
460 allConverged++;
461 }
462 }
463 } while (allConverged < numRuns && numIter < this->MaxNumIterations);
464
465 // add columns to output table
466 vtkTable* outputTable = vtkTable::New();
467 outputTable->AddColumn(clusterRunIDs);
468 outputTable->AddColumn(numberOfClusters);
469 outputTable->AddColumn(numIterations);
470 outputTable->AddColumn(error);
471 outputTable->AddColumn(numDataElementsInCluster);
472 for (vtkIdType i = 0; i < newClusterElements->GetNumberOfColumns(); ++i)
473 {
474 outputTable->AddColumn(newClusterElements->GetColumn(i));
475 }
476
477 outMeta->SetNumberOfBlocks(1);
478 outMeta->SetBlock(0, outputTable);
479 outMeta->GetMetaData(static_cast<unsigned>(0))
480 ->Set(vtkCompositeDataSet::NAME(), "Updated Cluster Centers");
481
482 clusterRunIDs->Delete();
483 numberOfClusters->Delete();
484 numDataElementsInCluster->Delete();
485 numIterations->Delete();
486 error->Delete();
487 curClusterElements->Delete();
488 newClusterElements->Delete();
489 dataElements->Delete();
490 clusterMemberID->Delete();
491 outputTable->Delete();
492 startRunID->Delete();
493 endRunID->Delete();
494 computeRun->Delete();
495 numMembershipChanges->Delete();
496 }
497
498 //------------------------------------------------------------------------------
Derive(vtkMultiBlockDataSet * outMeta)499 void vtkKMeansStatistics::Derive(vtkMultiBlockDataSet* outMeta)
500 {
501 vtkTable* outTable;
502 vtkIdTypeArray* clusterRunIDs;
503 vtkIdTypeArray* numIterations;
504 vtkIdTypeArray* numberOfClusters;
505 vtkDoubleArray* error;
506
507 if (!outMeta || !(outTable = vtkTable::SafeDownCast(outMeta->GetBlock(0))) ||
508 !(clusterRunIDs = vtkArrayDownCast<vtkIdTypeArray>(outTable->GetColumn(0))) ||
509 !(numberOfClusters = vtkArrayDownCast<vtkIdTypeArray>(outTable->GetColumn(1))) ||
510 !(numIterations = vtkArrayDownCast<vtkIdTypeArray>(outTable->GetColumn(2))) ||
511 !(error = vtkArrayDownCast<vtkDoubleArray>(outTable->GetColumn(3))))
512 {
513 return;
514 }
515
516 // Create an output table
517 // outMeta and which is presumed to exist upon entry to Derive).
518
519 outMeta->SetNumberOfBlocks(2);
520
521 vtkIdTypeArray* totalClusterRunIDs = vtkIdTypeArray::New();
522 vtkIdTypeArray* totalNumberOfClusters = vtkIdTypeArray::New();
523 vtkIdTypeArray* totalNumIterations = vtkIdTypeArray::New();
524 vtkIdTypeArray* globalRank = vtkIdTypeArray::New();
525 vtkIdTypeArray* localRank = vtkIdTypeArray::New();
526 vtkDoubleArray* totalError = vtkDoubleArray::New();
527
528 totalClusterRunIDs->SetName(clusterRunIDs->GetName());
529 totalNumberOfClusters->SetName(numberOfClusters->GetName());
530 totalNumIterations->SetName(numIterations->GetName());
531 totalError->SetName("Total Error");
532 globalRank->SetName("Global Rank");
533 localRank->SetName("Local Rank");
534
535 std::multimap<double, vtkIdType> globalErrorMap;
536 std::map<vtkIdType, std::multimap<double, vtkIdType>> localErrorMap;
537
538 vtkIdType curRow = 0;
539 while (curRow < outTable->GetNumberOfRows())
540 {
541 totalClusterRunIDs->InsertNextValue(clusterRunIDs->GetValue(curRow));
542 totalNumIterations->InsertNextValue(numIterations->GetValue(curRow));
543 totalNumberOfClusters->InsertNextValue(numberOfClusters->GetValue(curRow));
544 double totalErr = 0.0;
545 for (vtkIdType i = curRow; i < curRow + numberOfClusters->GetValue(curRow); ++i)
546 {
547 totalErr += error->GetValue(i);
548 }
549 totalError->InsertNextValue(totalErr);
550 globalErrorMap.insert(
551 std::multimap<double, vtkIdType>::value_type(totalErr, clusterRunIDs->GetValue(curRow)));
552 localErrorMap[numberOfClusters->GetValue(curRow)].insert(
553 std::multimap<double, vtkIdType>::value_type(totalErr, clusterRunIDs->GetValue(curRow)));
554 curRow += numberOfClusters->GetValue(curRow);
555 }
556
557 globalRank->SetNumberOfValues(totalClusterRunIDs->GetNumberOfTuples());
558 localRank->SetNumberOfValues(totalClusterRunIDs->GetNumberOfTuples());
559 int rankID = 1;
560
561 for (std::multimap<double, vtkIdType>::iterator itr = globalErrorMap.begin();
562 itr != globalErrorMap.end(); ++itr)
563 {
564 globalRank->SetValue(itr->second, rankID++);
565 }
566 for (std::map<vtkIdType, std::multimap<double, vtkIdType>>::iterator itr = localErrorMap.begin();
567 itr != localErrorMap.end(); ++itr)
568 {
569 rankID = 1;
570 for (std::multimap<double, vtkIdType>::iterator rItr = itr->second.begin();
571 rItr != itr->second.end(); ++rItr)
572 {
573 localRank->SetValue(rItr->second, rankID++);
574 }
575 }
576
577 vtkTable* ranked = vtkTable::New();
578 outMeta->SetBlock(1, ranked);
579 outMeta->GetMetaData(static_cast<unsigned>(1))
580 ->Set(vtkCompositeDataSet::NAME(), "Ranked Cluster Centers");
581 ranked->Delete(); // outMeta now owns ranked
582 ranked->AddColumn(totalClusterRunIDs);
583 ranked->AddColumn(totalNumberOfClusters);
584 ranked->AddColumn(totalNumIterations);
585 ranked->AddColumn(totalError);
586 ranked->AddColumn(localRank);
587 ranked->AddColumn(globalRank);
588
589 totalError->Delete();
590 localRank->Delete();
591 globalRank->Delete();
592 totalClusterRunIDs->Delete();
593 totalNumberOfClusters->Delete();
594 totalNumIterations->Delete();
595 }
596
597 //------------------------------------------------------------------------------
Assess(vtkTable * inData,vtkMultiBlockDataSet * inMeta,vtkTable * outData)598 void vtkKMeansStatistics::Assess(vtkTable* inData, vtkMultiBlockDataSet* inMeta, vtkTable* outData)
599 {
600 if (!inData)
601 {
602 return;
603 }
604
605 if (!inMeta)
606 {
607 return;
608 }
609
610 // Add a column to the output data related to the each input datum wrt the model in the request.
611 // Column names of the metadata and input data are assumed to match.
612 // The output columns will be named "this->AssessNames->GetValue(0)(A,B,C)" where
613 // "A", "B", and "C" are the column names specified in the per-request metadata tables.
614 AssessFunctor* dfunc = nullptr;
615 // only one request allowed in when learning, so there will only be one
616 vtkTable* reqModel = vtkTable::SafeDownCast(inMeta->GetBlock(0));
617 if (!reqModel)
618 {
619 // silently skip invalid entries. Note we leave assessValues column in output data even when
620 // it's empty.
621 return;
622 }
623
624 this->SelectAssessFunctor(inData, reqModel, nullptr, dfunc);
625 if (!dfunc)
626 {
627 vtkWarningMacro("Assessment could not be accommodated. Skipping.");
628 return;
629 }
630
631 vtkKMeansAssessFunctor* kmfunc = static_cast<vtkKMeansAssessFunctor*>(dfunc);
632
633 vtkIdType nv = this->AssessNames->GetNumberOfValues();
634 int numRuns = kmfunc->GetNumberOfRuns();
635 std::vector<vtkStdString> names(nv * numRuns);
636 vtkIdType nRow = inData->GetNumberOfRows();
637 for (int i = 0; i < numRuns; ++i)
638 {
639 for (vtkIdType v = 0; v < nv; ++v)
640 {
641 std::ostringstream assessColName;
642 assessColName << this->AssessNames->GetValue(v) << "(" << i << ")";
643
644 vtkAbstractArray* assessValues;
645 if (v)
646 { // The "closest id" column for each request will always be integer-valued
647 assessValues = vtkIntArray::New();
648 }
649 else
650 { // We'll assume for now that the "distance" column for each request will be a real number.
651 assessValues = vtkDoubleArray::New();
652 }
653 names[i * nv + v] =
654 assessColName.str()
655 .c_str(); // Storing names to be able to use SetValueByName which is faster than SetValue
656 assessValues->SetName(names[i * nv + v]);
657 assessValues->SetNumberOfTuples(nRow);
658 outData->AddColumn(assessValues);
659 assessValues->Delete();
660 }
661 }
662
663 // Assess each entry of the column
664 vtkDoubleArray* assessResult = vtkDoubleArray::New();
665 for (vtkIdType r = 0; r < nRow; ++r)
666 {
667 (*dfunc)(assessResult, r);
668 for (vtkIdType j = 0; j < nv * numRuns; ++j)
669 {
670 outData->SetValueByName(r, names[j], assessResult->GetValue(j));
671 }
672 }
673 assessResult->Delete();
674
675 delete dfunc;
676 }
677
678 //------------------------------------------------------------------------------
SelectAssessFunctor(vtkTable * inData,vtkDataObject * inMetaDO,vtkStringArray * vtkNotUsed (rowNames),AssessFunctor * & dfunc)679 void vtkKMeansStatistics::SelectAssessFunctor(vtkTable* inData, vtkDataObject* inMetaDO,
680 vtkStringArray* vtkNotUsed(rowNames), AssessFunctor*& dfunc)
681 {
682 (void)inData;
683
684 dfunc = nullptr;
685 vtkTable* reqModel = vtkTable::SafeDownCast(inMetaDO);
686 if (!reqModel)
687 {
688 return;
689 }
690
691 if (!this->DistanceFunctor)
692 {
693 vtkErrorMacro("Distance functor is nullptr");
694 return;
695 }
696
697 vtkKMeansAssessFunctor* kmfunc = vtkKMeansAssessFunctor::New();
698
699 if (!kmfunc->Initialize(inData, reqModel, this->DistanceFunctor))
700 {
701 delete kmfunc;
702 return;
703 }
704 dfunc = kmfunc;
705 }
706
707 //------------------------------------------------------------------------------
New()708 vtkKMeansAssessFunctor* vtkKMeansAssessFunctor::New()
709 {
710 return new vtkKMeansAssessFunctor;
711 }
712
713 //------------------------------------------------------------------------------
~vtkKMeansAssessFunctor()714 vtkKMeansAssessFunctor::~vtkKMeansAssessFunctor()
715 {
716 this->ClusterMemberIDs->Delete();
717 this->Distances->Delete();
718 }
719
720 //------------------------------------------------------------------------------
Initialize(vtkTable * inData,vtkTable * inModel,vtkKMeansDistanceFunctor * dfunc)721 bool vtkKMeansAssessFunctor::Initialize(
722 vtkTable* inData, vtkTable* inModel, vtkKMeansDistanceFunctor* dfunc)
723 {
724 vtkIdType numObservations = inData->GetNumberOfRows();
725 vtkTable* dataElements = vtkTable::New();
726 vtkTable* curClusterElements = vtkTable::New();
727 vtkIdTypeArray* startRunID = vtkIdTypeArray::New();
728 vtkIdTypeArray* endRunID = vtkIdTypeArray::New();
729
730 this->Distances = vtkDoubleArray::New();
731 this->ClusterMemberIDs = vtkIdTypeArray::New();
732 this->NumRuns = 0;
733
734 // cluster coordinates start in column 5 of the inModel table
735 for (vtkIdType i = 5; i < inModel->GetNumberOfColumns(); ++i)
736 {
737 curClusterElements->AddColumn(inModel->GetColumn(i));
738 dataElements->AddColumn(inData->GetColumnByName(inModel->GetColumnName(i)));
739 }
740
741 vtkIdType curRow = 0;
742 while (curRow < inModel->GetNumberOfRows())
743 {
744 this->NumRuns++;
745 startRunID->InsertNextValue(curRow);
746 // number of clusters "K" is stored in column 1 of the inModel table
747 curRow += inModel->GetValue(curRow, 1).ToInt();
748 endRunID->InsertNextValue(curRow);
749 }
750
751 this->Distances->SetNumberOfValues(numObservations * this->NumRuns);
752 this->ClusterMemberIDs->SetNumberOfValues(numObservations * this->NumRuns);
753
754 // find minimum distance between each data object and cluster center
755 for (vtkIdType observation = 0; observation < numObservations; ++observation)
756 {
757 for (int runID = 0; runID < this->NumRuns; ++runID)
758 {
759 vtkIdType runStartIdx = startRunID->GetValue(runID);
760 vtkIdType runEndIdx = endRunID->GetValue(runID);
761 if (runStartIdx >= runEndIdx)
762 {
763 continue;
764 }
765 // Find the closest cluster center to the observation across all centers in the runID-th run.
766 vtkIdType j = runStartIdx;
767 double minDistance = 0.0;
768 double curDistance = 0.0;
769 (*dfunc)(minDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
770 vtkIdType localMemberID = 0;
771 for (/* no init */; j < runEndIdx; ++j)
772 {
773 (*dfunc)(curDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
774 if (curDistance < minDistance)
775 {
776 minDistance = curDistance;
777 localMemberID = j - runStartIdx;
778 }
779 }
780 this->ClusterMemberIDs->SetValue(observation * this->NumRuns + runID, localMemberID);
781 this->Distances->SetValue(observation * this->NumRuns + runID, minDistance);
782 }
783 }
784
785 dataElements->Delete();
786 curClusterElements->Delete();
787 startRunID->Delete();
788 endRunID->Delete();
789 return true;
790 }
791
792 //------------------------------------------------------------------------------
operator ()(vtkDoubleArray * result,vtkIdType row)793 void vtkKMeansAssessFunctor::operator()(vtkDoubleArray* result, vtkIdType row)
794 {
795
796 result->SetNumberOfValues(2 * this->NumRuns);
797 vtkIdType resIndex = 0;
798 for (int runID = 0; runID < this->NumRuns; runID++)
799 {
800 result->SetValue(resIndex++, this->Distances->GetValue(row * this->NumRuns + runID));
801 result->SetValue(resIndex++, this->ClusterMemberIDs->GetValue(row * this->NumRuns + runID));
802 }
803 }
804