1 #include "MUQ/Modeling/Flann/FlannCache.h"
2
3 using namespace muq::Modeling;
4
FlannCache(std::shared_ptr<ModPiece> function)5 FlannCache::FlannCache(std::shared_ptr<ModPiece> function) : ModPiece(function->inputSizes, function->outputSizes), // can only have one input and output
6 function(function),
7 kdTree(std::make_shared<DynamicKDTreeAdaptor<>>(function->inputSizes(0))) {
8
9 // the target function can only have one input/output
10 assert(function->numInputs==1);
11 assert(function->numOutputs==1);
12 centroid = Eigen::VectorXd::Zero(inputSizes(0));
13 }
14
~FlannCache()15 FlannCache::~FlannCache() {}
16
EvaluateImpl(ref_vector<Eigen::VectorXd> const & inputs)17 void FlannCache::EvaluateImpl(ref_vector<Eigen::VectorXd> const& inputs) {
18 int cacheId = InCache(inputs.at(0));
19 outputs.resize(1);
20 if(cacheId < 0){
21 Add(inputs.at(0));
22 outputs.at(0) = outputCache.at(outputCache.size()-1);
23 }else{
24 outputs.at(0) = outputCache.at(cacheId);
25 }
26 }
27
InCache(Eigen::VectorXd const & input) const28 int FlannCache::InCache(Eigen::VectorXd const& input) const {
29 if( Size()>0 ) { // if there are points in the cache
30 std::vector<size_t> indices;
31 std::vector<double> squaredDists;
32 std::tie(indices, squaredDists) = kdTree->query(input, 1);
33
34 if(squaredDists.at(0)<std::numeric_limits<double>::epsilon()){
35 return indices.at(0);
36 }
37 }
38
39 // the cache is either empty or none of the points in a small radius are exactly the point we care about
40 return -1;
41 }
42
Add(Eigen::VectorXd const & newPt)43 Eigen::VectorXd FlannCache::Add(Eigen::VectorXd const& newPt) {
44 // evaluate the function
45 assert(function);
46 const Eigen::VectorXd& newOutput = function->Evaluate(newPt).at(0);
47
48 // add the new point
49 Add(newPt, newOutput);
50
51 // return the result
52 return newOutput;
53 }
54
Add(Eigen::VectorXd const & input,Eigen::VectorXd const & result)55 unsigned int FlannCache::Add(Eigen::VectorXd const& input, Eigen::VectorXd const& result) {
56 assert(input.size()==function->inputSizes(0));
57 assert(result.size()==function->outputSizes(0));
58
59 int cacheId = InCache(input);
60
61 if(cacheId<0){
62 kdTree->add(input);
63 outputCache.push_back(result);
64
65 assert(kdTree->m_data.size()==outputCache.size());
66
67 UpdateCentroid(input);
68
69 return outputCache.size()-1;
70
71 }else{
72 return cacheId;
73 }
74 }
75
Remove(Eigen::VectorXd const & input)76 void FlannCache::Remove(Eigen::VectorXd const& input) {
77 // get the index of the point
78 const int id = InCache(input);
79
80 // the point is not in the cache ... nothing to remove
81 if( id<0 ) { return; }
82
83 // remove from output
84 outputCache.erase(outputCache.begin()+id);
85
86 // remove from input
87 kdTree->m_data.erase(kdTree->m_data.begin()+id);
88 kdTree->UpdateIndex();
89 }
90
NearestNeighborIndex(Eigen::VectorXd const & point) const91 size_t FlannCache::NearestNeighborIndex(Eigen::VectorXd const& point) const {
92 // make sure we have enough
93 assert(1<=Size());
94
95 std::vector<size_t> indices;
96 std::vector<double> squaredDists;
97 std::tie(indices, squaredDists) = kdTree->query(point, 1);
98 assert(indices.size()==1);
99 assert(squaredDists.size()==1);
100
101 return indices[0];
102 }
103
NearestNeighbors(Eigen::VectorXd const & point,unsigned int const k,std::vector<Eigen::VectorXd> & neighbors,std::vector<Eigen::VectorXd> & result) const104 void FlannCache::NearestNeighbors(Eigen::VectorXd const& point,
105 unsigned int const k,
106 std::vector<Eigen::VectorXd>& neighbors,
107 std::vector<Eigen::VectorXd>& result) const {
108 // make sure we have enough
109 assert(k<=Size());
110
111 std::vector<size_t> indices;
112 std::vector<double> squaredDists;
113 std::tie(indices, squaredDists) = kdTree->query(point, k);
114 assert(indices.size()==k);
115 assert(squaredDists.size()==k);
116
117 neighbors.resize(k);
118 result.resize(k);
119 for( unsigned int i=0; i<k; ++i ){
120 neighbors.at(i) = kdTree->m_data.at(indices[i]);
121 result.at(i) = outputCache.at(indices[i]);
122 }
123 }
124
NearestNeighbors(Eigen::VectorXd const & point,unsigned int const k,std::vector<Eigen::VectorXd> & neighbors) const125 void FlannCache::NearestNeighbors(Eigen::VectorXd const& point,
126 unsigned int const k,
127 std::vector<Eigen::VectorXd>& neighbors) const {
128 // make sure we have enough
129 assert(k<=Size());
130
131 std::vector<size_t> indices;
132 std::vector<double> squaredDists;
133 std::tie(indices, squaredDists) = kdTree->query(point, k);
134 assert(indices.size()==k);
135 assert(squaredDists.size()==k);
136
137 neighbors.resize(k);
138 for( unsigned int i=0; i<k; ++i ){ neighbors.at(i) = kdTree->m_data.at(indices[i]); }
139 }
140
Size() const141 unsigned int FlannCache::Size() const {
142 // these two numbers should be the same unless we check the size after adding the input but before the model finishings running
143 return std::min(kdTree->m_data.size(), outputCache.size());
144 }
145
Add(std::vector<Eigen::VectorXd> const & inputs)146 std::vector<Eigen::VectorXd> FlannCache::Add(std::vector<Eigen::VectorXd> const& inputs) {
147 std::vector<Eigen::VectorXd> results(inputs.size());
148
149 for( unsigned int i=0; i<inputs.size(); ++i ) {
150 // see if the point is already there
151 const int index = InCache(inputs[i]);
152
153 // add the point if is not already there
154 results[i] = index<0? Add(inputs[i]) : outputCache.at(index);
155
156 // make sure it got added
157 assert(InCache(inputs[i])>=0);
158 }
159
160 assert(kdTree->m_data.size()==outputCache.size());
161
162 return results;
163 }
164
Add(std::vector<Eigen::VectorXd> const & inputs,std::vector<Eigen::VectorXd> const & results)165 void FlannCache::Add(std::vector<Eigen::VectorXd> const& inputs, std::vector<Eigen::VectorXd> const& results) {
166 assert(inputs.size()==results.size());
167
168 for( unsigned int i=0; i<inputs.size(); ++i ) {
169 // add the point to cache (with result)
170 Add(inputs[i], results[i]);
171
172 // make sure it got added
173 assert(InCache(inputs[i])>=0);
174 }
175
176 assert(kdTree->m_data.size()==outputCache.size());
177 }
178
at(unsigned int const index) const179 const Eigen::VectorXd FlannCache::at(unsigned int const index) const {
180 assert(index<kdTree->m_data.size());
181 return kdTree->m_data[index];
182 }
183
at(unsigned int const index)184 Eigen::VectorXd FlannCache::at(unsigned int const index) {
185 assert(index<kdTree->m_data.size());
186 return kdTree->m_data[index];
187 }
188
OutputValue(unsigned int index) const189 Eigen::VectorXd const& FlannCache::OutputValue(unsigned int index) const{
190 return outputCache.at(index);
191 }
192
UpdateCentroid(Eigen::VectorXd const & point)193 void FlannCache::UpdateCentroid(Eigen::VectorXd const& point) {
194 centroid = ((double)(Size()-1)*centroid+point)/(double)Size();
195 }
196
Centroid() const197 Eigen::VectorXd FlannCache::Centroid() const { return centroid; }
198
Function() const199 std::shared_ptr<ModPiece> FlannCache::Function() const { return function; }
200