1 #include "vtkKMeansStatistics.h"
2 #include "vtkKMeansAssessFunctor.h"
3 #include "vtkKMeansDistanceFunctor.h"
4 #include "vtkStringArray.h"
5 
6 #include "vtkDataObject.h"
7 #include "vtkDoubleArray.h"
8 #include "vtkIdTypeArray.h"
9 #include "vtkInformation.h"
10 #include "vtkIntArray.h"
11 #include "vtkMultiBlockDataSet.h"
12 #include "vtkObjectFactory.h"
13 #include "vtkStatisticsAlgorithmPrivate.h"
14 #include "vtkTable.h"
15 #include "vtkVariantArray.h"
16 
17 #include <map>
18 #include <sstream>
19 #include <vector>
20 
21 vtkStandardNewMacro(vtkKMeansStatistics);
22 vtkCxxSetObjectMacro(vtkKMeansStatistics, DistanceFunctor, vtkKMeansDistanceFunctor);
23 
24 //------------------------------------------------------------------------------
vtkKMeansStatistics()25 vtkKMeansStatistics::vtkKMeansStatistics()
26 {
27   this->AssessNames->SetNumberOfValues(2);
28   this->AssessNames->SetValue(0, "Distance");
29   this->AssessNames->SetValue(1, "ClosestId");
30   this->DefaultNumberOfClusters = 3;
31   this->Tolerance = 0.01;
32   this->KValuesArrayName = nullptr;
33   this->SetKValuesArrayName("K");
34   this->MaxNumIterations = 50;
35   this->DistanceFunctor = vtkKMeansDistanceFunctor::New();
36 }
37 
38 //------------------------------------------------------------------------------
~vtkKMeansStatistics()39 vtkKMeansStatistics::~vtkKMeansStatistics()
40 {
41   this->SetKValuesArrayName(nullptr);
42   this->SetDistanceFunctor(nullptr);
43 }
44 
45 //------------------------------------------------------------------------------
PrintSelf(ostream & os,vtkIndent indent)46 void vtkKMeansStatistics::PrintSelf(ostream& os, vtkIndent indent)
47 {
48   this->Superclass::PrintSelf(os, indent);
49   os << indent << "DefaultNumberofClusters: " << this->DefaultNumberOfClusters << endl;
50   os << indent << "KValuesArrayName: \""
51      << (this->KValuesArrayName ? this->KValuesArrayName : "nullptr") << "\"\n";
52   os << indent << "MaxNumIterations: " << this->MaxNumIterations << endl;
53   os << indent << "Tolerance: " << this->Tolerance << endl;
54   os << indent << "DistanceFunctor: " << this->DistanceFunctor << endl;
55 }
56 
57 //------------------------------------------------------------------------------
InitializeDataAndClusterCenters(vtkTable * inParameters,vtkTable * inData,vtkTable * dataElements,vtkIdTypeArray * numberOfClusters,vtkTable * curClusterElements,vtkTable * newClusterElements,vtkIdTypeArray * startRunID,vtkIdTypeArray * endRunID)58 int vtkKMeansStatistics::InitializeDataAndClusterCenters(vtkTable* inParameters, vtkTable* inData,
59   vtkTable* dataElements, vtkIdTypeArray* numberOfClusters, vtkTable* curClusterElements,
60   vtkTable* newClusterElements, vtkIdTypeArray* startRunID, vtkIdTypeArray* endRunID)
61 {
62   std::set<std::set<vtkStdString>>::const_iterator reqIt;
63   if (this->Internals->Requests.size() > 1)
64   {
65     static int num = 0;
66     num++;
67     if (num < 10)
68     {
69       vtkWarningMacro("Only the first request will be processed -- the rest will be ignored.");
70     }
71   }
72 
73   if (this->Internals->Requests.empty())
74   {
75     vtkErrorMacro("No requests were made.");
76     return 0;
77   }
78   reqIt = this->Internals->Requests.begin();
79 
80   vtkIdType numToAllocate;
81   vtkIdType numRuns = 0;
82 
83   int initialClusterCentersProvided = 0;
84 
85   // process parameter input table
86   if (inParameters && inParameters->GetNumberOfRows() > 0 && inParameters->GetNumberOfColumns() > 1)
87   {
88     vtkIdTypeArray* counts = vtkArrayDownCast<vtkIdTypeArray>(inParameters->GetColumn(0));
89     if (!counts)
90     {
91       vtkWarningMacro("The first column of the input parameter table should be of vtkIdType."
92         << endl
93         << "The input table provided will be ignored and a single run will be performed using the "
94            "first "
95         << this->DefaultNumberOfClusters << " observations as the initial cluster centers.");
96     }
97     else
98     {
99       initialClusterCentersProvided = 1;
100       numToAllocate = inParameters->GetNumberOfRows();
101       numberOfClusters->SetNumberOfValues(numToAllocate);
102       numberOfClusters->SetName(inParameters->GetColumn(0)->GetName());
103 
104       for (vtkIdType i = 0; i < numToAllocate; ++i)
105       {
106         numberOfClusters->SetValue(i, counts->GetValue(i));
107       }
108       vtkIdType curRow = 0;
109       while (curRow < inParameters->GetNumberOfRows())
110       {
111         numRuns++;
112         startRunID->InsertNextValue(curRow);
113         curRow += inParameters->GetValue(curRow, 0).ToInt();
114         endRunID->InsertNextValue(curRow);
115       }
116       vtkTable* condensedTable = vtkTable::New();
117       std::set<vtkStdString>::const_iterator colItr;
118       for (colItr = reqIt->begin(); colItr != reqIt->end(); ++colItr)
119       {
120         vtkAbstractArray* pArr = inParameters->GetColumnByName(colItr->c_str());
121         vtkAbstractArray* dArr = inData->GetColumnByName(colItr->c_str());
122         if (pArr && dArr)
123         {
124           condensedTable->AddColumn(pArr);
125           dataElements->AddColumn(dArr);
126         }
127         else
128         {
129           vtkWarningMacro("Skipping requested column \"" << colItr->c_str() << "\".");
130         }
131       }
132       newClusterElements->DeepCopy(condensedTable);
133       curClusterElements->DeepCopy(condensedTable);
134       condensedTable->Delete();
135     }
136   }
137   if (!initialClusterCentersProvided)
138   {
139     // otherwise create an initial set of cluster coords
140     numRuns = 1;
141     numToAllocate = this->DefaultNumberOfClusters < inData->GetNumberOfRows()
142       ? this->DefaultNumberOfClusters
143       : inData->GetNumberOfRows();
144     startRunID->InsertNextValue(0);
145     endRunID->InsertNextValue(numToAllocate);
146     numberOfClusters->SetName(this->KValuesArrayName);
147 
148     for (vtkIdType j = 0; j < inData->GetNumberOfColumns(); j++)
149     {
150       if (reqIt->find(inData->GetColumnName(j)) != reqIt->end())
151       {
152         vtkAbstractArray* curCoords = this->DistanceFunctor->CreateCoordinateArray();
153         vtkAbstractArray* newCoords = this->DistanceFunctor->CreateCoordinateArray();
154         curCoords->SetName(inData->GetColumnName(j));
155         newCoords->SetName(inData->GetColumnName(j));
156         curClusterElements->AddColumn(curCoords);
157         newClusterElements->AddColumn(newCoords);
158         curCoords->Delete();
159         newCoords->Delete();
160         dataElements->AddColumn(inData->GetColumnByName(inData->GetColumnName(j)));
161       }
162     }
163     CreateInitialClusterCenters(
164       numToAllocate, numberOfClusters, inData, curClusterElements, newClusterElements);
165   }
166 
167   if (curClusterElements->GetNumberOfColumns() == 0)
168   {
169     return 0;
170   }
171   return numRuns;
172 }
173 
174 //------------------------------------------------------------------------------
CreateInitialClusterCenters(vtkIdType numToAllocate,vtkIdTypeArray * numberOfClusters,vtkTable * inData,vtkTable * curClusterElements,vtkTable * newClusterElements)175 void vtkKMeansStatistics::CreateInitialClusterCenters(vtkIdType numToAllocate,
176   vtkIdTypeArray* numberOfClusters, vtkTable* inData, vtkTable* curClusterElements,
177   vtkTable* newClusterElements)
178 {
179   std::set<std::set<vtkStdString>>::const_iterator reqIt;
180   if (this->Internals->Requests.size() > 1)
181   {
182     static int num = 0;
183     ++num;
184     if (num < 10)
185     {
186       vtkWarningMacro("Only the first request will be processed -- the rest will be ignored.");
187     }
188   }
189 
190   if (this->Internals->Requests.empty())
191   {
192     vtkErrorMacro("No requests were made.");
193     return;
194   }
195   reqIt = this->Internals->Requests.begin();
196 
197   for (vtkIdType i = 0; i < numToAllocate; ++i)
198   {
199     numberOfClusters->InsertNextValue(numToAllocate);
200     vtkVariantArray* curRow = vtkVariantArray::New();
201     vtkVariantArray* newRow = vtkVariantArray::New();
202     for (int j = 0; j < inData->GetNumberOfColumns(); j++)
203     {
204       if (reqIt->find(inData->GetColumnName(j)) != reqIt->end())
205       {
206         curRow->InsertNextValue(inData->GetValue(i, j));
207         newRow->InsertNextValue(inData->GetValue(i, j));
208       }
209     }
210     curClusterElements->InsertNextRow(curRow);
211     newClusterElements->InsertNextRow(newRow);
212     curRow->Delete();
213     newRow->Delete();
214   }
215 }
216 
217 //------------------------------------------------------------------------------
GetTotalNumberOfObservations(vtkIdType numObservations)218 vtkIdType vtkKMeansStatistics::GetTotalNumberOfObservations(vtkIdType numObservations)
219 {
220   return numObservations;
221 }
222 
223 //------------------------------------------------------------------------------
UpdateClusterCenters(vtkTable * newClusterElements,vtkTable * curClusterElements,vtkIdTypeArray * vtkNotUsed (numMembershipChanges),vtkIdTypeArray * numDataElementsInCluster,vtkDoubleArray * vtkNotUsed (error),vtkIdTypeArray * startRunID,vtkIdTypeArray * endRunID,vtkIntArray * computeRun)224 void vtkKMeansStatistics::UpdateClusterCenters(vtkTable* newClusterElements,
225   vtkTable* curClusterElements, vtkIdTypeArray* vtkNotUsed(numMembershipChanges),
226   vtkIdTypeArray* numDataElementsInCluster, vtkDoubleArray* vtkNotUsed(error),
227   vtkIdTypeArray* startRunID, vtkIdTypeArray* endRunID, vtkIntArray* computeRun)
228 {
229   for (vtkIdType runID = 0; runID < startRunID->GetNumberOfTuples(); runID++)
230   {
231     if (computeRun->GetValue(runID))
232     {
233       for (vtkIdType i = startRunID->GetValue(runID); i < endRunID->GetValue(runID); ++i)
234       {
235         if (numDataElementsInCluster->GetValue(i) == 0)
236         {
237           vtkWarningMacro("cluster center " << i - startRunID->GetValue(runID) << " in run "
238                                             << runID << " is degenerate. Attempting to perturb");
239           this->DistanceFunctor->PerturbElement(newClusterElements, curClusterElements, i,
240             startRunID->GetValue(runID), endRunID->GetValue(runID), 0.8);
241         }
242       }
243     }
244   }
245 }
246 
247 //------------------------------------------------------------------------------
SetParameter(const char * parameter,int vtkNotUsed (index),vtkVariant value)248 bool vtkKMeansStatistics::SetParameter(
249   const char* parameter, int vtkNotUsed(index), vtkVariant value)
250 {
251   if (!parameter)
252     return false;
253 
254   vtkStdString pname = parameter;
255   if (pname == "DefaultNumberOfClusters" || pname == "k" || pname == "K")
256   {
257     bool valid;
258     int k = value.ToInt(&valid);
259     if (valid && k > 0)
260     {
261       this->SetDefaultNumberOfClusters(k);
262       return true;
263     }
264   }
265   else if (pname == "Tolerance")
266   {
267     double tol = value.ToDouble();
268     this->SetTolerance(tol);
269     return true;
270   }
271   else if (pname == "MaxNumIterations")
272   {
273     bool valid;
274     int maxit = value.ToInt(&valid);
275     if (valid && maxit >= 0)
276     {
277       this->SetMaxNumIterations(maxit);
278       return true;
279     }
280   }
281 
282   return false;
283 }
284 
285 //------------------------------------------------------------------------------
Learn(vtkTable * inData,vtkTable * inParameters,vtkMultiBlockDataSet * outMeta)286 void vtkKMeansStatistics::Learn(
287   vtkTable* inData, vtkTable* inParameters, vtkMultiBlockDataSet* outMeta)
288 {
289   if (!outMeta)
290   {
291     return;
292   }
293 
294   if (!inData)
295   {
296     return;
297   }
298 
299   if (!this->DistanceFunctor)
300   {
301     vtkErrorMacro("Distance functor is nullptr");
302     return;
303   }
304 
305   // Data initialization
306   vtkIdTypeArray* numberOfClusters = vtkIdTypeArray::New();
307   vtkTable* curClusterElements = vtkTable::New();
308   vtkTable* newClusterElements = vtkTable::New();
309   vtkIdTypeArray* startRunID = vtkIdTypeArray::New();
310   vtkIdTypeArray* endRunID = vtkIdTypeArray::New();
311   vtkTable* dataElements = vtkTable::New();
312   int numRuns = InitializeDataAndClusterCenters(inParameters, inData, dataElements,
313     numberOfClusters, curClusterElements, newClusterElements, startRunID, endRunID);
314   if (numRuns == 0)
315   {
316     numberOfClusters->Delete();
317     curClusterElements->Delete();
318     newClusterElements->Delete();
319     startRunID->Delete();
320     endRunID->Delete();
321     dataElements->Delete();
322     return;
323   }
324 
325   vtkIdType numObservations = inData->GetNumberOfRows();
326   vtkIdType totalNumberOfObservations = this->GetTotalNumberOfObservations(numObservations);
327   vtkIdType numToAllocate = curClusterElements->GetNumberOfRows();
328   vtkIdTypeArray* numIterations = vtkIdTypeArray::New();
329   vtkIdTypeArray* numDataElementsInCluster = vtkIdTypeArray::New();
330   vtkDoubleArray* error = vtkDoubleArray::New();
331   vtkIdTypeArray* clusterMemberID = vtkIdTypeArray::New();
332   vtkIdTypeArray* numMembershipChanges = vtkIdTypeArray::New();
333   vtkIntArray* computeRun = vtkIntArray::New();
334   vtkIdTypeArray* clusterRunIDs = vtkIdTypeArray::New();
335 
336   numDataElementsInCluster->SetNumberOfValues(numToAllocate);
337   numDataElementsInCluster->SetName("Cardinality");
338   clusterRunIDs->SetNumberOfValues(numToAllocate);
339   clusterRunIDs->SetName("Run ID");
340   error->SetNumberOfValues(numToAllocate);
341   error->SetName("Error");
342   numIterations->SetNumberOfValues(numToAllocate);
343   numIterations->SetName("Iterations");
344   numMembershipChanges->SetNumberOfValues(numRuns);
345   computeRun->SetNumberOfValues(numRuns);
346   clusterMemberID->SetNumberOfValues(numObservations * numRuns);
347   clusterMemberID->SetName("cluster member id");
348 
349   for (int i = 0; i < numRuns; ++i)
350   {
351     for (vtkIdType j = startRunID->GetValue(i); j < endRunID->GetValue(i); j++)
352     {
353       clusterRunIDs->SetValue(j, i);
354     }
355   }
356 
357   numIterations->FillComponent(0, 0);
358   computeRun->FillComponent(0, 1);
359   int allConverged, numIter = 0;
360   clusterMemberID->FillComponent(0, -1);
361 
362   // Iterate until new cluster centers have converged OR we have reached a max number of iterations
363   do
364   {
365     // Initialize coordinates, cluster sizes and errors
366     numMembershipChanges->FillComponent(0, 0);
367     for (int runID = 0; runID < numRuns; runID++)
368     {
369       if (computeRun->GetValue(runID))
370       {
371         for (vtkIdType j = startRunID->GetValue(runID); j < endRunID->GetValue(runID); j++)
372         {
373           curClusterElements->SetRow(j, newClusterElements->GetRow(j));
374           newClusterElements->SetRow(
375             j, this->DistanceFunctor->GetEmptyTuple(newClusterElements->GetNumberOfColumns()));
376           numDataElementsInCluster->SetValue(j, 0);
377           error->SetValue(j, 0.0);
378         }
379       }
380     }
381 
382     // Find minimum distance between each observation and each cluster center,
383     // then assign the observation to the nearest cluster.
384     vtkIdType localMemberID, offsetLocalMemberID;
385     double minDistance, curDistance;
386     for (vtkIdType observation = 0; observation < dataElements->GetNumberOfRows(); observation++)
387     {
388       for (int runID = 0; runID < numRuns; runID++)
389       {
390         if (computeRun->GetValue(runID))
391         {
392           vtkIdType runStartIdx = startRunID->GetValue(runID);
393           vtkIdType runEndIdx = endRunID->GetValue(runID);
394           if (runStartIdx >= runEndIdx)
395           {
396             continue;
397           }
398           vtkIdType j = runStartIdx;
399           localMemberID = 0;
400           offsetLocalMemberID = runStartIdx;
401           (*this->DistanceFunctor)(
402             minDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
403           curDistance = minDistance;
404           ++j;
405           for (/* no init */; j < runEndIdx; j++)
406           {
407             (*this->DistanceFunctor)(
408               curDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
409             if (curDistance < minDistance)
410             {
411               minDistance = curDistance;
412               localMemberID = j - runStartIdx;
413               offsetLocalMemberID = j;
414             }
415           }
416           // We've located the nearest cluster center. Has it changed since the last iteration?
417           if (clusterMemberID->GetValue(observation * numRuns + runID) != localMemberID)
418           {
419             numMembershipChanges->SetValue(runID, numMembershipChanges->GetValue(runID) + 1);
420             clusterMemberID->SetValue(observation * numRuns + runID, localMemberID);
421           }
422           // Give the distance functor a chance to modify any derived quantities used to
423           // change the cluster centers between iterations, now that we know which cluster
424           // center the observation is assigned to.
425           vtkIdType newCardinality = numDataElementsInCluster->GetValue(offsetLocalMemberID) + 1;
426           numDataElementsInCluster->SetValue(offsetLocalMemberID, newCardinality);
427           this->DistanceFunctor->PairwiseUpdate(newClusterElements, offsetLocalMemberID,
428             dataElements->GetRow(observation), 1, newCardinality);
429           // Update the error for this cluster center to account for this observation.
430           error->SetValue(offsetLocalMemberID, error->GetValue(offsetLocalMemberID) + minDistance);
431         }
432       }
433     }
434     // update cluster centers
435     this->UpdateClusterCenters(newClusterElements, curClusterElements, numMembershipChanges,
436       numDataElementsInCluster, error, startRunID, endRunID, computeRun);
437 
438     // check for convergence
439     numIter++;
440     allConverged = 0;
441 
442     for (int j = 0; j < numRuns; j++)
443     {
444       if (computeRun->GetValue(j))
445       {
446         double percentChanged = static_cast<double>(numMembershipChanges->GetValue(j)) /
447           static_cast<double>(totalNumberOfObservations);
448         if (percentChanged < this->Tolerance || numIter == this->MaxNumIterations)
449         {
450           allConverged++;
451           computeRun->SetValue(j, 0);
452           for (int k = startRunID->GetValue(j); k < endRunID->GetValue(j); k++)
453           {
454             numIterations->SetValue(k, numIter);
455           }
456         }
457       }
458       else
459       {
460         allConverged++;
461       }
462     }
463   } while (allConverged < numRuns && numIter < this->MaxNumIterations);
464 
465   // add columns to output table
466   vtkTable* outputTable = vtkTable::New();
467   outputTable->AddColumn(clusterRunIDs);
468   outputTable->AddColumn(numberOfClusters);
469   outputTable->AddColumn(numIterations);
470   outputTable->AddColumn(error);
471   outputTable->AddColumn(numDataElementsInCluster);
472   for (vtkIdType i = 0; i < newClusterElements->GetNumberOfColumns(); ++i)
473   {
474     outputTable->AddColumn(newClusterElements->GetColumn(i));
475   }
476 
477   outMeta->SetNumberOfBlocks(1);
478   outMeta->SetBlock(0, outputTable);
479   outMeta->GetMetaData(static_cast<unsigned>(0))
480     ->Set(vtkCompositeDataSet::NAME(), "Updated Cluster Centers");
481 
482   clusterRunIDs->Delete();
483   numberOfClusters->Delete();
484   numDataElementsInCluster->Delete();
485   numIterations->Delete();
486   error->Delete();
487   curClusterElements->Delete();
488   newClusterElements->Delete();
489   dataElements->Delete();
490   clusterMemberID->Delete();
491   outputTable->Delete();
492   startRunID->Delete();
493   endRunID->Delete();
494   computeRun->Delete();
495   numMembershipChanges->Delete();
496 }
497 
498 //------------------------------------------------------------------------------
Derive(vtkMultiBlockDataSet * outMeta)499 void vtkKMeansStatistics::Derive(vtkMultiBlockDataSet* outMeta)
500 {
501   vtkTable* outTable;
502   vtkIdTypeArray* clusterRunIDs;
503   vtkIdTypeArray* numIterations;
504   vtkIdTypeArray* numberOfClusters;
505   vtkDoubleArray* error;
506 
507   if (!outMeta || !(outTable = vtkTable::SafeDownCast(outMeta->GetBlock(0))) ||
508     !(clusterRunIDs = vtkArrayDownCast<vtkIdTypeArray>(outTable->GetColumn(0))) ||
509     !(numberOfClusters = vtkArrayDownCast<vtkIdTypeArray>(outTable->GetColumn(1))) ||
510     !(numIterations = vtkArrayDownCast<vtkIdTypeArray>(outTable->GetColumn(2))) ||
511     !(error = vtkArrayDownCast<vtkDoubleArray>(outTable->GetColumn(3))))
512   {
513     return;
514   }
515 
516   // Create an output table
517   // outMeta and which is presumed to exist upon entry to Derive).
518 
519   outMeta->SetNumberOfBlocks(2);
520 
521   vtkIdTypeArray* totalClusterRunIDs = vtkIdTypeArray::New();
522   vtkIdTypeArray* totalNumberOfClusters = vtkIdTypeArray::New();
523   vtkIdTypeArray* totalNumIterations = vtkIdTypeArray::New();
524   vtkIdTypeArray* globalRank = vtkIdTypeArray::New();
525   vtkIdTypeArray* localRank = vtkIdTypeArray::New();
526   vtkDoubleArray* totalError = vtkDoubleArray::New();
527 
528   totalClusterRunIDs->SetName(clusterRunIDs->GetName());
529   totalNumberOfClusters->SetName(numberOfClusters->GetName());
530   totalNumIterations->SetName(numIterations->GetName());
531   totalError->SetName("Total Error");
532   globalRank->SetName("Global Rank");
533   localRank->SetName("Local Rank");
534 
535   std::multimap<double, vtkIdType> globalErrorMap;
536   std::map<vtkIdType, std::multimap<double, vtkIdType>> localErrorMap;
537 
538   vtkIdType curRow = 0;
539   while (curRow < outTable->GetNumberOfRows())
540   {
541     totalClusterRunIDs->InsertNextValue(clusterRunIDs->GetValue(curRow));
542     totalNumIterations->InsertNextValue(numIterations->GetValue(curRow));
543     totalNumberOfClusters->InsertNextValue(numberOfClusters->GetValue(curRow));
544     double totalErr = 0.0;
545     for (vtkIdType i = curRow; i < curRow + numberOfClusters->GetValue(curRow); ++i)
546     {
547       totalErr += error->GetValue(i);
548     }
549     totalError->InsertNextValue(totalErr);
550     globalErrorMap.insert(
551       std::multimap<double, vtkIdType>::value_type(totalErr, clusterRunIDs->GetValue(curRow)));
552     localErrorMap[numberOfClusters->GetValue(curRow)].insert(
553       std::multimap<double, vtkIdType>::value_type(totalErr, clusterRunIDs->GetValue(curRow)));
554     curRow += numberOfClusters->GetValue(curRow);
555   }
556 
557   globalRank->SetNumberOfValues(totalClusterRunIDs->GetNumberOfTuples());
558   localRank->SetNumberOfValues(totalClusterRunIDs->GetNumberOfTuples());
559   int rankID = 1;
560 
561   for (std::multimap<double, vtkIdType>::iterator itr = globalErrorMap.begin();
562        itr != globalErrorMap.end(); ++itr)
563   {
564     globalRank->SetValue(itr->second, rankID++);
565   }
566   for (std::map<vtkIdType, std::multimap<double, vtkIdType>>::iterator itr = localErrorMap.begin();
567        itr != localErrorMap.end(); ++itr)
568   {
569     rankID = 1;
570     for (std::multimap<double, vtkIdType>::iterator rItr = itr->second.begin();
571          rItr != itr->second.end(); ++rItr)
572     {
573       localRank->SetValue(rItr->second, rankID++);
574     }
575   }
576 
577   vtkTable* ranked = vtkTable::New();
578   outMeta->SetBlock(1, ranked);
579   outMeta->GetMetaData(static_cast<unsigned>(1))
580     ->Set(vtkCompositeDataSet::NAME(), "Ranked Cluster Centers");
581   ranked->Delete(); // outMeta now owns ranked
582   ranked->AddColumn(totalClusterRunIDs);
583   ranked->AddColumn(totalNumberOfClusters);
584   ranked->AddColumn(totalNumIterations);
585   ranked->AddColumn(totalError);
586   ranked->AddColumn(localRank);
587   ranked->AddColumn(globalRank);
588 
589   totalError->Delete();
590   localRank->Delete();
591   globalRank->Delete();
592   totalClusterRunIDs->Delete();
593   totalNumberOfClusters->Delete();
594   totalNumIterations->Delete();
595 }
596 
597 //------------------------------------------------------------------------------
Assess(vtkTable * inData,vtkMultiBlockDataSet * inMeta,vtkTable * outData)598 void vtkKMeansStatistics::Assess(vtkTable* inData, vtkMultiBlockDataSet* inMeta, vtkTable* outData)
599 {
600   if (!inData)
601   {
602     return;
603   }
604 
605   if (!inMeta)
606   {
607     return;
608   }
609 
610   // Add a column to the output data related to the each input datum wrt the model in the request.
611   // Column names of the metadata and input data are assumed to match.
612   // The output columns will be named "this->AssessNames->GetValue(0)(A,B,C)" where
613   // "A", "B", and "C" are the column names specified in the per-request metadata tables.
614   AssessFunctor* dfunc = nullptr;
615   // only one request allowed in when learning, so there will only be one
616   vtkTable* reqModel = vtkTable::SafeDownCast(inMeta->GetBlock(0));
617   if (!reqModel)
618   {
619     // silently skip invalid entries. Note we leave assessValues column in output data even when
620     // it's empty.
621     return;
622   }
623 
624   this->SelectAssessFunctor(inData, reqModel, nullptr, dfunc);
625   if (!dfunc)
626   {
627     vtkWarningMacro("Assessment could not be accommodated. Skipping.");
628     return;
629   }
630 
631   vtkKMeansAssessFunctor* kmfunc = static_cast<vtkKMeansAssessFunctor*>(dfunc);
632 
633   vtkIdType nv = this->AssessNames->GetNumberOfValues();
634   int numRuns = kmfunc->GetNumberOfRuns();
635   std::vector<vtkStdString> names(nv * numRuns);
636   vtkIdType nRow = inData->GetNumberOfRows();
637   for (int i = 0; i < numRuns; ++i)
638   {
639     for (vtkIdType v = 0; v < nv; ++v)
640     {
641       std::ostringstream assessColName;
642       assessColName << this->AssessNames->GetValue(v) << "(" << i << ")";
643 
644       vtkAbstractArray* assessValues;
645       if (v)
646       { // The "closest id" column for each request will always be integer-valued
647         assessValues = vtkIntArray::New();
648       }
649       else
650       { // We'll assume for now that the "distance" column for each request will be a real number.
651         assessValues = vtkDoubleArray::New();
652       }
653       names[i * nv + v] =
654         assessColName.str()
655           .c_str(); // Storing names to be able to use SetValueByName which is faster than SetValue
656       assessValues->SetName(names[i * nv + v]);
657       assessValues->SetNumberOfTuples(nRow);
658       outData->AddColumn(assessValues);
659       assessValues->Delete();
660     }
661   }
662 
663   // Assess each entry of the column
664   vtkDoubleArray* assessResult = vtkDoubleArray::New();
665   for (vtkIdType r = 0; r < nRow; ++r)
666   {
667     (*dfunc)(assessResult, r);
668     for (vtkIdType j = 0; j < nv * numRuns; ++j)
669     {
670       outData->SetValueByName(r, names[j], assessResult->GetValue(j));
671     }
672   }
673   assessResult->Delete();
674 
675   delete dfunc;
676 }
677 
678 //------------------------------------------------------------------------------
SelectAssessFunctor(vtkTable * inData,vtkDataObject * inMetaDO,vtkStringArray * vtkNotUsed (rowNames),AssessFunctor * & dfunc)679 void vtkKMeansStatistics::SelectAssessFunctor(vtkTable* inData, vtkDataObject* inMetaDO,
680   vtkStringArray* vtkNotUsed(rowNames), AssessFunctor*& dfunc)
681 {
682   (void)inData;
683 
684   dfunc = nullptr;
685   vtkTable* reqModel = vtkTable::SafeDownCast(inMetaDO);
686   if (!reqModel)
687   {
688     return;
689   }
690 
691   if (!this->DistanceFunctor)
692   {
693     vtkErrorMacro("Distance functor is nullptr");
694     return;
695   }
696 
697   vtkKMeansAssessFunctor* kmfunc = vtkKMeansAssessFunctor::New();
698 
699   if (!kmfunc->Initialize(inData, reqModel, this->DistanceFunctor))
700   {
701     delete kmfunc;
702     return;
703   }
704   dfunc = kmfunc;
705 }
706 
707 //------------------------------------------------------------------------------
New()708 vtkKMeansAssessFunctor* vtkKMeansAssessFunctor::New()
709 {
710   return new vtkKMeansAssessFunctor;
711 }
712 
713 //------------------------------------------------------------------------------
~vtkKMeansAssessFunctor()714 vtkKMeansAssessFunctor::~vtkKMeansAssessFunctor()
715 {
716   this->ClusterMemberIDs->Delete();
717   this->Distances->Delete();
718 }
719 
720 //------------------------------------------------------------------------------
Initialize(vtkTable * inData,vtkTable * inModel,vtkKMeansDistanceFunctor * dfunc)721 bool vtkKMeansAssessFunctor::Initialize(
722   vtkTable* inData, vtkTable* inModel, vtkKMeansDistanceFunctor* dfunc)
723 {
724   vtkIdType numObservations = inData->GetNumberOfRows();
725   vtkTable* dataElements = vtkTable::New();
726   vtkTable* curClusterElements = vtkTable::New();
727   vtkIdTypeArray* startRunID = vtkIdTypeArray::New();
728   vtkIdTypeArray* endRunID = vtkIdTypeArray::New();
729 
730   this->Distances = vtkDoubleArray::New();
731   this->ClusterMemberIDs = vtkIdTypeArray::New();
732   this->NumRuns = 0;
733 
734   // cluster coordinates start in column 5 of the inModel table
735   for (vtkIdType i = 5; i < inModel->GetNumberOfColumns(); ++i)
736   {
737     curClusterElements->AddColumn(inModel->GetColumn(i));
738     dataElements->AddColumn(inData->GetColumnByName(inModel->GetColumnName(i)));
739   }
740 
741   vtkIdType curRow = 0;
742   while (curRow < inModel->GetNumberOfRows())
743   {
744     this->NumRuns++;
745     startRunID->InsertNextValue(curRow);
746     // number of clusters "K" is stored in column 1 of the inModel table
747     curRow += inModel->GetValue(curRow, 1).ToInt();
748     endRunID->InsertNextValue(curRow);
749   }
750 
751   this->Distances->SetNumberOfValues(numObservations * this->NumRuns);
752   this->ClusterMemberIDs->SetNumberOfValues(numObservations * this->NumRuns);
753 
754   // find minimum distance between each data object and cluster center
755   for (vtkIdType observation = 0; observation < numObservations; ++observation)
756   {
757     for (int runID = 0; runID < this->NumRuns; ++runID)
758     {
759       vtkIdType runStartIdx = startRunID->GetValue(runID);
760       vtkIdType runEndIdx = endRunID->GetValue(runID);
761       if (runStartIdx >= runEndIdx)
762       {
763         continue;
764       }
765       // Find the closest cluster center to the observation across all centers in the runID-th run.
766       vtkIdType j = runStartIdx;
767       double minDistance = 0.0;
768       double curDistance = 0.0;
769       (*dfunc)(minDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
770       vtkIdType localMemberID = 0;
771       for (/* no init */; j < runEndIdx; ++j)
772       {
773         (*dfunc)(curDistance, curClusterElements->GetRow(j), dataElements->GetRow(observation));
774         if (curDistance < minDistance)
775         {
776           minDistance = curDistance;
777           localMemberID = j - runStartIdx;
778         }
779       }
780       this->ClusterMemberIDs->SetValue(observation * this->NumRuns + runID, localMemberID);
781       this->Distances->SetValue(observation * this->NumRuns + runID, minDistance);
782     }
783   }
784 
785   dataElements->Delete();
786   curClusterElements->Delete();
787   startRunID->Delete();
788   endRunID->Delete();
789   return true;
790 }
791 
792 //------------------------------------------------------------------------------
operator ()(vtkDoubleArray * result,vtkIdType row)793 void vtkKMeansAssessFunctor::operator()(vtkDoubleArray* result, vtkIdType row)
794 {
795 
796   result->SetNumberOfValues(2 * this->NumRuns);
797   vtkIdType resIndex = 0;
798   for (int runID = 0; runID < this->NumRuns; runID++)
799   {
800     result->SetValue(resIndex++, this->Distances->GetValue(row * this->NumRuns + runID));
801     result->SetValue(resIndex++, this->ClusterMemberIDs->GetValue(row * this->NumRuns + runID));
802   }
803 }
804