1 #include "MUQ/Modeling/Flann/FlannCache.h"
2 
3 using namespace muq::Modeling;
4 
FlannCache(std::shared_ptr<ModPiece> function)5 FlannCache::FlannCache(std::shared_ptr<ModPiece> function) : ModPiece(function->inputSizes, function->outputSizes), // can only have one input and output
6 							     function(function),
7 							     kdTree(std::make_shared<DynamicKDTreeAdaptor<>>(function->inputSizes(0))) {
8 
9   // the target function can only have one input/output
10   assert(function->numInputs==1);
11   assert(function->numOutputs==1);
12 	centroid = Eigen::VectorXd::Zero(inputSizes(0));
13 }
14 
~FlannCache()15 FlannCache::~FlannCache() {}
16 
EvaluateImpl(ref_vector<Eigen::VectorXd> const & inputs)17 void FlannCache::EvaluateImpl(ref_vector<Eigen::VectorXd> const& inputs) {
18     int cacheId = InCache(inputs.at(0));
19     outputs.resize(1);
20     if(cacheId < 0){
21       Add(inputs.at(0));
22       outputs.at(0) = outputCache.at(outputCache.size()-1);
23     }else{
24       outputs.at(0) = outputCache.at(cacheId);
25     }
26 }
27 
InCache(Eigen::VectorXd const & input) const28 int FlannCache::InCache(Eigen::VectorXd const& input) const {
29   if( Size()>0 ) { // if there are points in the cache
30     std::vector<size_t> indices;
31     std::vector<double> squaredDists;
32     std::tie(indices, squaredDists) = kdTree->query(input, 1);
33 
34     if(squaredDists.at(0)<std::numeric_limits<double>::epsilon()){
35       return indices.at(0);
36     }
37   }
38 
39   // the cache is either empty or none of the points in a small radius are exactly the point we care about
40   return -1;
41 }
42 
Add(Eigen::VectorXd const & newPt)43 Eigen::VectorXd FlannCache::Add(Eigen::VectorXd const& newPt) {
44   // evaluate the function
45 	assert(function);
46   const Eigen::VectorXd& newOutput = function->Evaluate(newPt).at(0);
47 
48   // add the new point
49   Add(newPt, newOutput);
50 
51   // return the result
52   return newOutput;
53 }
54 
Add(Eigen::VectorXd const & input,Eigen::VectorXd const & result)55 unsigned int FlannCache::Add(Eigen::VectorXd const& input, Eigen::VectorXd const& result) {
56   assert(input.size()==function->inputSizes(0));
57   assert(result.size()==function->outputSizes(0));
58 
59 	int cacheId = InCache(input);
60 
61 	if(cacheId<0){
62 	  kdTree->add(input);
63 	  outputCache.push_back(result);
64 
65 	  assert(kdTree->m_data.size()==outputCache.size());
66 
67 		UpdateCentroid(input);
68 
69 		return outputCache.size()-1;
70 
71 	}else{
72 		return cacheId;
73 	}
74 }
75 
Remove(Eigen::VectorXd const & input)76 void FlannCache::Remove(Eigen::VectorXd const& input) {
77   // get the index of the point
78   const int id = InCache(input);
79 
80   // the point is not in the cache ... nothing to remove
81   if( id<0 ) { return; }
82 
83   // remove from output
84   outputCache.erase(outputCache.begin()+id);
85 
86   // remove from input
87   kdTree->m_data.erase(kdTree->m_data.begin()+id);
88   kdTree->UpdateIndex();
89 }
90 
NearestNeighborIndex(Eigen::VectorXd const & point) const91 size_t FlannCache::NearestNeighborIndex(Eigen::VectorXd const& point) const {
92   // make sure we have enough
93   assert(1<=Size());
94 
95   std::vector<size_t> indices;
96   std::vector<double> squaredDists;
97   std::tie(indices, squaredDists) = kdTree->query(point, 1);
98   assert(indices.size()==1);
99   assert(squaredDists.size()==1);
100 
101   return indices[0];
102 }
103 
NearestNeighbors(Eigen::VectorXd const & point,unsigned int const k,std::vector<Eigen::VectorXd> & neighbors,std::vector<Eigen::VectorXd> & result) const104 void FlannCache::NearestNeighbors(Eigen::VectorXd const& point,
105                                   unsigned int const k,
106                                   std::vector<Eigen::VectorXd>& neighbors,
107                                   std::vector<Eigen::VectorXd>& result) const {
108   // make sure we have enough
109   assert(k<=Size());
110 
111   std::vector<size_t> indices;
112   std::vector<double> squaredDists;
113   std::tie(indices, squaredDists) = kdTree->query(point, k);
114   assert(indices.size()==k);
115   assert(squaredDists.size()==k);
116 
117   neighbors.resize(k);
118   result.resize(k);
119   for( unsigned int i=0; i<k; ++i ){
120     neighbors.at(i) = kdTree->m_data.at(indices[i]);
121     result.at(i) = outputCache.at(indices[i]);
122   }
123 }
124 
NearestNeighbors(Eigen::VectorXd const & point,unsigned int const k,std::vector<Eigen::VectorXd> & neighbors) const125 void FlannCache::NearestNeighbors(Eigen::VectorXd const& point,
126                                   unsigned int const k,
127                                   std::vector<Eigen::VectorXd>& neighbors) const {
128   // make sure we have enough
129   assert(k<=Size());
130 
131   std::vector<size_t> indices;
132   std::vector<double> squaredDists;
133   std::tie(indices, squaredDists) = kdTree->query(point, k);
134   assert(indices.size()==k);
135   assert(squaredDists.size()==k);
136 
137   neighbors.resize(k);
138   for( unsigned int i=0; i<k; ++i ){ neighbors.at(i) = kdTree->m_data.at(indices[i]); }
139 }
140 
Size() const141 unsigned int FlannCache::Size() const {
142   // these two numbers should be the same unless we check the size after adding the input but before the model finishings running
143   return std::min(kdTree->m_data.size(), outputCache.size());
144 }
145 
Add(std::vector<Eigen::VectorXd> const & inputs)146 std::vector<Eigen::VectorXd> FlannCache::Add(std::vector<Eigen::VectorXd> const& inputs) {
147   std::vector<Eigen::VectorXd> results(inputs.size());
148 
149   for( unsigned int i=0; i<inputs.size(); ++i ) {
150     // see if the point is already there
151     const int index = InCache(inputs[i]);
152 
153     // add the point if is not already there
154     results[i] = index<0? Add(inputs[i]) : outputCache.at(index);
155 
156     // make sure it got added
157     assert(InCache(inputs[i])>=0);
158   }
159 
160   assert(kdTree->m_data.size()==outputCache.size());
161 
162   return results;
163 }
164 
Add(std::vector<Eigen::VectorXd> const & inputs,std::vector<Eigen::VectorXd> const & results)165 void FlannCache::Add(std::vector<Eigen::VectorXd> const& inputs, std::vector<Eigen::VectorXd> const& results) {
166   assert(inputs.size()==results.size());
167 
168   for( unsigned int i=0; i<inputs.size(); ++i ) {
169     // add the point to cache (with result)
170     Add(inputs[i], results[i]);
171 
172     // make sure it got added
173     assert(InCache(inputs[i])>=0);
174   }
175 
176   assert(kdTree->m_data.size()==outputCache.size());
177 }
178 
at(unsigned int const index) const179 const Eigen::VectorXd FlannCache::at(unsigned int const index) const {
180   assert(index<kdTree->m_data.size());
181   return kdTree->m_data[index];
182 }
183 
at(unsigned int const index)184 Eigen::VectorXd FlannCache::at(unsigned int const index) {
185   assert(index<kdTree->m_data.size());
186   return kdTree->m_data[index];
187 }
188 
OutputValue(unsigned int index) const189 Eigen::VectorXd const& FlannCache::OutputValue(unsigned int index) const{
190 	return outputCache.at(index);
191 }
192 
UpdateCentroid(Eigen::VectorXd const & point)193 void FlannCache::UpdateCentroid(Eigen::VectorXd const& point) {
194 	centroid = ((double)(Size()-1)*centroid+point)/(double)Size();
195 }
196 
Centroid() const197 Eigen::VectorXd FlannCache::Centroid() const { return centroid; }
198 
Function() const199 std::shared_ptr<ModPiece> FlannCache::Function() const { return function; }
200