1 /*=========================================================================
2 
3   Program:   Visualization Toolkit
4   Module:    vtkDotProductSimilarity.cxx
5 
6 -------------------------------------------------------------------------
7   Copyright 2008 Sandia Corporation.
8   Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
9   the U.S. Government retains certain rights in this software.
10 -------------------------------------------------------------------------
11 
12   Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
13   All rights reserved.
14   See Copyright.txt or http://www.kitware.com/Copyright.htm for details.
15 
16      This software is distributed WITHOUT ANY WARRANTY; without even
17      the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
18      PURPOSE.  See the above copyright notice for more information.
19 
20 =========================================================================*/
21 
22 #include "vtkDotProductSimilarity.h"
23 #include "vtkArrayData.h"
24 #include "vtkCommand.h"
25 #include "vtkDenseArray.h"
26 #include "vtkDoubleArray.h"
27 #include "vtkIdTypeArray.h"
28 #include "vtkInformation.h"
29 #include "vtkInformationVector.h"
30 #include "vtkObjectFactory.h"
31 #include "vtkSmartPointer.h"
32 #include "vtkTable.h"
33 
34 #include <algorithm>
35 #include <map>
36 #include <stdexcept>
37 
// threshold_multimap
// A std::multimap used by vtkDotProductSimilarity to collect, for a single
// source vector, the (similarity, target) pairs that survive pruning.  Each
// call to insert() adds one entry and then enforces two rules, always
// discarding the entries with the smallest keys first:
//
// - Threshold rule: entries whose key is below minimum_threshold are removed,
//   but only while more than minimum_count entries remain.  In other words the
//   minimum count can override the threshold.
// - Capacity rule: no more than maximum_count entries are ever kept.
//
// Only the single-value insert() is public.  The hint and range overloads of
// std::multimap::insert are declared private (and never defined) so callers
// cannot bypass the pruning.

template <typename KeyT, typename ValueT>
class threshold_multimap : public std::multimap<KeyT, ValueT, std::less<KeyT>>
{
  using container_t = std::multimap<KeyT, ValueT, std::less<KeyT>>;

public:
  // minimum_threshold: keys below this value are candidates for removal.
  // minimum_count: threshold pruning stops once this many entries remain.
  // maximum_count: hard upper bound on the number of stored entries.
  // NOTE(review): the counts are size_t.  If a caller passes a negative signed
  // value, it has already wrapped to a very large number by the time it gets
  // here, so no clamping to zero is possible at this point.
  threshold_multimap(KeyT minimum_threshold, size_t minimum_count, size_t maximum_count)
    : MinimumThreshold(minimum_threshold)
    , MinimumCount(minimum_count)
    , MaximumCount(maximum_count)
  {
  }

  // Adds value, then applies the threshold rule followed by the capacity rule.
  void insert(const typename container_t::value_type& value)
  {
    container_t::insert(value);

    // Threshold rule: drop the smallest entry while it is below the threshold
    // and more than MinimumCount entries are stored.
    for (;;)
    {
      if (this->size() <= this->MinimumCount)
        break;

      const auto smallest = this->begin();
      if (!(smallest->first < this->MinimumThreshold))
        break;

      this->erase(smallest);
    }

    // Capacity rule: trim from the small end until at most MaximumCount remain.
    while (this->MaximumCount < this->size())
      this->erase(this->begin());
  }

private:
  // Intentionally undefined: hides the std::multimap overloads that would skip pruning.
  typename container_t::iterator insert(
    typename container_t::iterator where, const typename container_t::value_type& value);
  template <class InIt>
  void insert(InIt first, InIt last);

  KeyT MinimumThreshold;
  size_t MinimumCount;
  size_t MaximumCount;
};
85 
86 //------------------------------------------------------------------------------
87 
88 vtkStandardNewMacro(vtkDotProductSimilarity);
89 
90 //------------------------------------------------------------------------------
91 
vtkDotProductSimilarity()92 vtkDotProductSimilarity::vtkDotProductSimilarity()
93   : VectorDimension(1)
94   , MinimumThreshold(1)
95   , MinimumCount(1)
96   , MaximumCount(10)
97   , UpperDiagonal(true)
98   , Diagonal(false)
99   , LowerDiagonal(false)
100   , FirstSecond(true)
101   , SecondFirst(true)
102 {
103   this->SetNumberOfInputPorts(2);
104   this->SetNumberOfOutputPorts(1);
105 }
106 
107 //------------------------------------------------------------------------------
108 
109 vtkDotProductSimilarity::~vtkDotProductSimilarity() = default;
110 
111 //------------------------------------------------------------------------------
112 
PrintSelf(ostream & os,vtkIndent indent)113 void vtkDotProductSimilarity::PrintSelf(ostream& os, vtkIndent indent)
114 {
115   this->Superclass::PrintSelf(os, indent);
116   os << indent << "VectorDimension: " << this->VectorDimension << endl;
117   os << indent << "MinimumThreshold: " << this->MinimumThreshold << endl;
118   os << indent << "MinimumCount: " << this->MinimumCount << endl;
119   os << indent << "MaximumCount: " << this->MaximumCount << endl;
120   os << indent << "UpperDiagonal: " << this->UpperDiagonal << endl;
121   os << indent << "Diagonal: " << this->Diagonal << endl;
122   os << indent << "LowerDiagonal: " << this->LowerDiagonal << endl;
123   os << indent << "FirstSecond: " << this->FirstSecond << endl;
124   os << indent << "SecondFirst: " << this->SecondFirst << endl;
125 }
126 
FillInputPortInformation(int port,vtkInformation * info)127 int vtkDotProductSimilarity::FillInputPortInformation(int port, vtkInformation* info)
128 {
129   switch (port)
130   {
131     case 0:
132       info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkArrayData");
133       return 1;
134     case 1:
135       info->Set(vtkAlgorithm::INPUT_IS_OPTIONAL(), 1);
136       info->Set(vtkAlgorithm::INPUT_REQUIRED_DATA_TYPE(), "vtkArrayData");
137       return 1;
138   }
139 
140   return 0;
141 }
142 
143 //------------------------------------------------------------------------------
144 
DotProduct(vtkDenseArray<double> * input_a,vtkDenseArray<double> * input_b,const vtkIdType vector_a,const vtkIdType vector_b,const vtkIdType vector_dimension,const vtkIdType component_dimension,const vtkArrayRange range_a,const vtkArrayRange range_b)145 static double DotProduct(vtkDenseArray<double>* input_a, vtkDenseArray<double>* input_b,
146   const vtkIdType vector_a, const vtkIdType vector_b, const vtkIdType vector_dimension,
147   const vtkIdType component_dimension, const vtkArrayRange range_a, const vtkArrayRange range_b)
148 {
149   vtkArrayCoordinates coordinates_a(0, 0);
150   vtkArrayCoordinates coordinates_b(0, 0);
151 
152   coordinates_a[vector_dimension] = vector_a;
153   coordinates_b[vector_dimension] = vector_b;
154 
155   double dot_product = 0.0;
156   for (vtkIdType component = 0; component != range_a.GetSize(); ++component)
157   {
158     coordinates_a[component_dimension] = component + range_a.GetBegin();
159     coordinates_b[component_dimension] = component + range_b.GetBegin();
160     dot_product += input_a->GetValue(coordinates_a) * input_b->GetValue(coordinates_b);
161   }
162   return dot_product;
163 }
164 
RequestData(vtkInformation *,vtkInformationVector ** inputVector,vtkInformationVector * outputVector)165 int vtkDotProductSimilarity::RequestData(
166   vtkInformation*, vtkInformationVector** inputVector, vtkInformationVector* outputVector)
167 {
168   try
169   {
170     // Enforce our preconditions ...
171     vtkArrayData* const input_a = vtkArrayData::GetData(inputVector[0]);
172     if (!input_a)
173       throw std::runtime_error("Missing array data input on input port 0.");
174     if (input_a->GetNumberOfArrays() != 1)
175       throw std::runtime_error("Array data on input port 0 must contain exactly one array.");
176     vtkDenseArray<double>* const input_array_a =
177       vtkDenseArray<double>::SafeDownCast(input_a->GetArray(static_cast<vtkIdType>(0)));
178     if (!input_array_a)
179       throw std::runtime_error("Array on input port 0 must be a vtkDenseArray<double>.");
180     if (input_array_a->GetDimensions() != 2)
181       throw std::runtime_error("Array on input port 0 must be a matrix.");
182 
183     vtkArrayData* const input_b = vtkArrayData::GetData(inputVector[1]);
184     vtkDenseArray<double>* input_array_b = nullptr;
185     if (input_b)
186     {
187       if (input_b->GetNumberOfArrays() != 1)
188         throw std::runtime_error("Array data on input port 1 must contain exactly one array.");
189       input_array_b =
190         vtkDenseArray<double>::SafeDownCast(input_b->GetArray(static_cast<vtkIdType>(0)));
191       if (!input_array_b)
192         throw std::runtime_error("Array on input port 1 must be a vtkDenseArray<double>.");
193       if (input_array_b->GetDimensions() != 2)
194         throw std::runtime_error("Array on input port 1 must be a matrix.");
195     }
196 
197     const vtkIdType vector_dimension = this->VectorDimension;
198     if (vector_dimension != 0 && vector_dimension != 1)
199       throw std::runtime_error("VectorDimension must be zero or one.");
200 
201     const vtkIdType component_dimension = 1 - vector_dimension;
202 
203     const vtkArrayRange vectors_a = input_array_a->GetExtent(vector_dimension);
204     const vtkArrayRange components_a = input_array_a->GetExtent(component_dimension);
205 
206     const vtkArrayRange vectors_b =
207       input_array_b ? input_array_b->GetExtent(vector_dimension) : vtkArrayRange();
208     const vtkArrayRange components_b =
209       input_array_b ? input_array_b->GetExtent(component_dimension) : vtkArrayRange();
210 
211     if (input_array_b && (components_a.GetSize() != components_b.GetSize()))
212       throw std::runtime_error("Input array vector lengths must match.");
213 
214     // Get output arrays ...
215     vtkTable* const output = vtkTable::GetData(outputVector);
216 
217     vtkIdTypeArray* const source_array = vtkIdTypeArray::New();
218     source_array->SetName("source");
219 
220     vtkIdTypeArray* const target_array = vtkIdTypeArray::New();
221     target_array->SetName("target");
222 
223     vtkDoubleArray* const similarity_array = vtkDoubleArray::New();
224     similarity_array->SetName("similarity");
225 
226     // Okay let outside world know that I'm starting
227     double progress = 0;
228     this->InvokeEvent(vtkCommand::ProgressEvent, &progress);
229 
230     typedef threshold_multimap<double, vtkIdType> similarities_t;
231     if (input_array_b)
232     {
233       // Compare the first matrix with the second matrix ...
234       if (this->FirstSecond)
235       {
236         for (vtkIdType vector_a = vectors_a.GetBegin(); vector_a != vectors_a.GetEnd(); ++vector_a)
237         {
238           similarities_t similarities(
239             this->MinimumThreshold, this->MinimumCount, this->MaximumCount);
240 
241           for (vtkIdType vector_b = vectors_b.GetBegin(); vector_b != vectors_b.GetEnd();
242                ++vector_b)
243           {
244             // Can't use std::make_pair - see
245             // http://sahajtechstyle.blogspot.com/2007/11/whats-wrong-with-sun-studio-c.html
246             similarities.insert(std::pair<const double, vtkIdType>(
247               DotProduct(input_array_a, input_array_b, vector_a, vector_b, vector_dimension,
248                 component_dimension, components_a, components_b),
249               vector_b));
250           }
251 
252           for (similarities_t::const_iterator similarity = similarities.begin();
253                similarity != similarities.end(); ++similarity)
254           {
255             source_array->InsertNextValue(vector_a);
256             target_array->InsertNextValue(similarity->second);
257             similarity_array->InsertNextValue(similarity->first);
258           }
259         }
260       }
261       // Compare the second matrix with the first matrix ...
262       if (this->SecondFirst)
263       {
264         for (vtkIdType vector_b = vectors_b.GetBegin(); vector_b != vectors_b.GetEnd(); ++vector_b)
265         {
266           similarities_t similarities(
267             this->MinimumThreshold, this->MinimumCount, this->MaximumCount);
268 
269           for (vtkIdType vector_a = vectors_a.GetBegin(); vector_a != vectors_a.GetEnd();
270                ++vector_a)
271           {
272             // Can't use std::make_pair - see
273             // http://sahajtechstyle.blogspot.com/2007/11/whats-wrong-with-sun-studio-c.html
274             similarities.insert(std::pair<const double, vtkIdType>(
275               DotProduct(input_array_b, input_array_a, vector_b, vector_a, vector_dimension,
276                 component_dimension, components_b, components_a),
277               vector_a));
278           }
279 
280           for (similarities_t::const_iterator similarity = similarities.begin();
281                similarity != similarities.end(); ++similarity)
282           {
283             source_array->InsertNextValue(vector_b);
284             target_array->InsertNextValue(similarity->second);
285             similarity_array->InsertNextValue(similarity->first);
286           }
287         }
288       }
289     }
290     // Compare the one matrix with itself ...
291     else
292     {
293       for (vtkIdType vector_a = vectors_a.GetBegin(); vector_a != vectors_a.GetEnd(); ++vector_a)
294       {
295         similarities_t similarities(this->MinimumThreshold, this->MinimumCount, this->MaximumCount);
296 
297         for (vtkIdType vector_b = vectors_a.GetBegin(); vector_b != vectors_a.GetEnd(); ++vector_b)
298         {
299           if ((vector_b > vector_a) && !this->UpperDiagonal)
300             continue;
301 
302           if ((vector_b == vector_a) && !this->Diagonal)
303             continue;
304 
305           if ((vector_b < vector_a) && !this->LowerDiagonal)
306             continue;
307 
308           // Can't use std::make_pair - see
309           // http://sahajtechstyle.blogspot.com/2007/11/whats-wrong-with-sun-studio-c.html
310           similarities.insert(std::pair<const double, vtkIdType>(
311             DotProduct(input_array_a, input_array_a, vector_a, vector_b, vector_dimension,
312               component_dimension, components_a, components_a),
313             vector_b));
314         }
315 
316         for (similarities_t::const_iterator similarity = similarities.begin();
317              similarity != similarities.end(); ++similarity)
318         {
319           source_array->InsertNextValue(vector_a);
320           target_array->InsertNextValue(similarity->second);
321           similarity_array->InsertNextValue(similarity->first);
322         }
323       }
324     }
325 
326     output->AddColumn(source_array);
327     output->AddColumn(target_array);
328     output->AddColumn(similarity_array);
329 
330     source_array->Delete();
331     target_array->Delete();
332     similarity_array->Delete();
333 
334     return 1;
335   }
336   catch (std::exception& e)
337   {
338     vtkErrorMacro(<< "unhandled exception: " << e.what());
339     return 0;
340   }
341   catch (...)
342   {
343     vtkErrorMacro(<< "unknown exception");
344     return 0;
345   }
346 }
347