1 //:
2 // \file
3 // \brief Tool to estimate how accurately a model can fit to sets of points
4 // \author Tim Cootes
5 // Given a model and a set of shapes, estimates the statistics of the errors on each point
6 // after fitting the model to each shape.
7 // Can either perform leave-some-out experiments on training set, or apply the model
8 // to a different set.
9 
10 #include <sstream>
11 #include <iostream>
12 #include <fstream>
13 #include <string>
14 #include <iterator>
15 #include <mbl/mbl_read_props.h>
16 #include <mbl/mbl_exception.h>
17 #include <mbl/mbl_parse_colon_pairs_list.h>
18 #include <mbl/mbl_parse_int_list.h>
19 #include "vul/vul_arg.h"
20 #include "vul/vul_string.h"
21 #ifdef _MSC_VER
22 #  include "vcl_msvc_warnings.h"
23 #endif
24 #include "vsl/vsl_quick_file.h"
25 
26 #include <msm/msm_shape_model_builder.h>
27 #include <msm/msm_shape_instance.h>
28 #include <msm/msm_reflect_shape.h>
29 
30 #include <msm/msm_add_all_loaders.h>
31 #include <mbl/mbl_stats_1d.h>
32 
33 /*
34 Parameter file format:
35 <START FILE>
36 //: Aligner for shape model
37 aligner: msm_similarity_aligner
38 
39 //: Object to apply limits to parameters
40 param_limiter: msm_ellipsoid_limiter { accept_prop: 0.98 }
41 
42 // Maximum number of shape modes
43 max_modes: 99
44 
45 // Proportion of shape variation to explain
46 var_prop: 0.95
47 
// Optional indices of points used to define a reference length
49 ref0: 0 ref1: 1
50 
51 // Number of chunks in n-fold cross validation
52 n_chunks: 10
53 
54 //: Define renumbering required under reflection
55 //  If defined, a reflected version of each shape is included in build
56 reflection_symmetry: { 7 6 5 4 3 2 1 0 }
57 
58 //: When true, only use reflection. When false, use both reflection and original.
59 only_reflect: false
60 
61 
62 image_dir: /home/images/
63 points_dir: /home/points/
64 images: {
65   image1.pts : image1.jpg
66   image2.pts : image2.jpg
67 }
68 
69 <END FILE>
70 */
71 
print_usage()72 void print_usage()
73 {
74   std::cout << "msm_estimate_residuals -p param_file -t test_points_list.txt\n"
75            << "Builds the shape model from the supplied data, tests on shapes in test_points_list.txt\n"
76            << "If no test_points_list.txt provided, performs leave-some-out tests on training data."
77            << std::endl;
78 
79   vul_arg_display_usage_and_exit();
80 }
81 
82 //: Structure to hold parameters
83 struct tool_params
84 {
85   //: Aligner for shape model
86   std::unique_ptr<msm_aligner> aligner;
87 
88   //: Object to apply limits to parameters
89   std::unique_ptr<msm_param_limiter> limiter;
90 
91   //: Maximum number of shape modes
92   unsigned max_modes;
93 
94   //: Proportion of shape variation to explain
95   double var_prop;
96 
97   //: Ref. point indices used to define reference length.
98   unsigned ref0,ref1;
99 
100   //: Number of chunks in n-fold cross validation
101   unsigned n_chunks;
102 
103   //: Define renumbering required under reflection
104   //  If defined, a reflected version of each shape is included in build
105   std::vector<unsigned> reflection_symmetry;
106 
107   //: When true, only use reflection. When false, use both reflection and original.
108   bool only_reflect;
109 
110 
111   //: Directory containing images
112   std::string image_dir;
113 
114   //: Directory containing points
115   std::string points_dir;
116 
117   //: List of image names
118   std::vector<std::string> image_names;
119 
120   //: List of points file names
121   std::vector<std::string> points_names;
122 
123   //: Parse named text file to read in data
124   //  Throws a mbl_exception_parse_error if fails
125   void read_from_file(const std::string& path);
126 };
127 
128 //: Parse named text file to read in data
129 //  Throws a mbl_exception_parse_error if fails
read_from_file(const std::string & path)130 void tool_params::read_from_file(const std::string& path)
131 {
132   std::ifstream ifs(path.c_str());
133   if (!ifs)
134   {
135     std::string error_msg = "Failed to open file: "+path;
136     throw (mbl_exception_parse_error(error_msg));
137   }
138 
139   mbl_read_props_type props = mbl_read_props_ws(ifs);
140 
141   max_modes=vul_string_atoi(props.get_optional_property("max_modes","99"));
142   var_prop=vul_string_atof(props.get_optional_property("var_prop","0.95"));
143   image_dir=props.get_optional_property("image_dir","./");
144   points_dir=props.get_optional_property("points_dir","./");
145 
146   ref0=vul_string_atoi(props.get_optional_property("ref0","0"));
147   ref1=vul_string_atoi(props.get_optional_property("ref1","0"));
148   n_chunks=vul_string_atoi(props.get_optional_property("n_chunks","10"));
149 
150   {
151     std::string aligner_str
152        = props.get_required_property("aligner");
153     std::stringstream ss(aligner_str);
154     aligner = msm_aligner::create_from_stream(ss);
155   }
156 
157   {
158     std::string limiter_str
159        = props.get_optional_property("param_limiter",
160                                      "msm_ellipsoid_limiter { accept_prop: 0.98 }");
161     std::stringstream ss(limiter_str);
162     limiter = msm_param_limiter::create_from_stream(ss);
163   }
164 
165   std::string ref_sym_str=props.get_optional_property("reflection_symmetry","-");
166   reflection_symmetry.resize(0);
167   if (ref_sym_str!="-")
168   {
169     std::stringstream ss(ref_sym_str);
170     mbl_parse_int_list(ss, std::back_inserter(reflection_symmetry),
171                        unsigned());
172   }
173 
174   only_reflect=vul_string_to_bool(props.get_optional_property("only_reflect","false"));
175 
176 
177   mbl_parse_colon_pairs_list(props.get_required_property("images"),
178                              points_names,image_names);
179 
180   // Don't look for unused props so can use a single common parameter file.
181 }
182 
//: Contents of an image list file
//  Holds the two directories and the paired lists of points/image
//  file names which read_from_file() extracts from such a file.
struct image_list_params
{
  //: Directory in which the images live
  std::string image_dir;

  //: Directory in which the points files live
  std::string points_dir;

  //: Names of the image files
  std::vector<std::string> image_names;

  //: Names of the points files
  std::vector<std::string> points_names;

  //: Fill this structure from the named text file
  //  Throws a mbl_exception_parse_error on failure
  void read_from_file(const std::string& path);
};
202 
203 //: Parse named text file to read in data
204 //  Throws a mbl_exception_parse_error if fails
read_from_file(const std::string & path)205 void image_list_params::read_from_file(const std::string& path)
206 {
207   std::ifstream ifs(path.c_str());
208   if (!ifs)
209   {
210     std::string error_msg = "Failed to open file: "+path;
211     throw (mbl_exception_parse_error(error_msg));
212   }
213 
214   mbl_read_props_type props = mbl_read_props_ws(ifs);
215 
216   image_dir=props.get_optional_property("image_dir","./");
217   points_dir=props.get_optional_property("points_dir","./");
218 
219   mbl_parse_colon_pairs_list(props.get_required_property("images"),
220                              points_names,image_names);
221 
222   // Don't look for unused props so can use a single common parameter file.
223 }
224 
225 
226 
227 struct msm_test_stats {
228   //: Stats of mean distance in world frame
229   mbl_stats_1d world_d_stats;
230   //: Stats of mean distance relative to a reference length (%)
231   mbl_stats_1d rel_d_stats;
232   //: Stats of mean distance in model reference frame
233   mbl_stats_1d ref_d_stats;
234 
235   //: Stats of residual x in reference frame
236   mbl_stats_1d ref_x_stats;
237   //: Stats of residual y in reference frame
238   mbl_stats_1d ref_y_stats;
239 };
240 
calc_point_distances(const msm_points & points1,const msm_points & points2,vnl_vector<double> & d)241 void calc_point_distances(const msm_points& points1, // labelled
242                             const msm_points& points2,    // predicted
243                             vnl_vector<double>& d)
244 {
245   d.set_size(points1.size());
246   d.fill(0.0);
247   for (unsigned i=0;i<points1.size();++i)
248     d[i]=(points1[i]-points2[i]).length();
249 }
250 
251 
test_model(const msm_shape_model & shape_model,int n_modes,const std::vector<msm_points> & points,unsigned ref0,unsigned ref1,msm_test_stats & stats)252 void test_model(const msm_shape_model& shape_model, int n_modes,
253                 const std::vector<msm_points>& points,
254                 unsigned ref0, unsigned ref1,
255                 msm_test_stats& stats)
256 {
257   msm_shape_instance sm_inst(shape_model);
258 
259   if (n_modes>=0)
260   {
261     // Arrange to use n_modes modes
262     vnl_vector<double> b(n_modes,0.0);
263     sm_inst.set_params(b);
264   }
265 
266   vnl_vector<double> d,inv_pose;
267   const msm_aligner& aligner = shape_model.aligner();
268   msm_points points_in_ref, dpoints;
269 
270   for (const auto & point : points)
271   {
272     sm_inst.fit_to_points(point);
273 
274     // Currently just compute overall distance
275     // Eventually need to project into model frame to compute individual errors
276     calc_point_distances(sm_inst.points(),point,d);
277     stats.world_d_stats.obs(d.mean());
278     if (ref0!=ref1)
279     {
280       double ref_d = (point[ref0]-point[ref1]).length();
281       stats.rel_d_stats.obs(100*d.mean()/ref_d);
282     }
283 
284     // Evaluate in the reference frame
285     inv_pose=aligner.inverse(sm_inst.pose());
286     aligner.apply_transform(point,inv_pose,points_in_ref);
287     calc_point_distances(sm_inst.model_points(),points_in_ref,d);
288     stats.ref_d_stats.obs(d.mean());
289 
290     dpoints.vector()=points_in_ref.vector()-sm_inst.model_points().vector();
291     for (unsigned j=0;j<dpoints.size();++j)
292     {
293       stats.ref_x_stats.obs(dpoints[j].x());
294       stats.ref_y_stats.obs(dpoints[j].y());
295     }
296   }
297 }
298 
299 
300 // Perform leave-some-out experiments, chopping data into n_chunks chunks
leave_some_out_tests(msm_shape_model_builder & builder,const std::vector<msm_points> & points,unsigned ref0,unsigned ref1,unsigned n_chunks,std::vector<msm_test_stats> & test_stats)301 void leave_some_out_tests(msm_shape_model_builder& builder,
302                           const std::vector<msm_points>& points,
303                           unsigned ref0, unsigned ref1,
304                           unsigned n_chunks,
305                           std::vector<msm_test_stats>& test_stats)
306 {
307   // Arrange to miss out consecutive examples.
308   double chunk_size=double(points.size())/n_chunks;
309   if (chunk_size<1) return;
310 
311   for (unsigned ic=0;ic<n_chunks;++ic)
312   {
313     std::vector<msm_points> trn_set,test_set;
314     for (unsigned i=0;i<points.size();++i)
315     {
316       if (unsigned(i/chunk_size)==ic) test_set.push_back(points[i]);
317       else                            trn_set.push_back(points[i]);
318     }
319     msm_shape_model shape_model;
320     builder.build_model(trn_set,shape_model);
321 
322     for (unsigned nm=0;nm<=shape_model.n_modes();++nm)
323     {
324       test_model(shape_model,nm,test_set,ref0,ref1,test_stats[nm]);
325     }
326   }
327 }
328 
main(int argc,char ** argv)329 int main(int argc, char** argv)
330 {
331   vul_arg<std::string> param_path("-p","Parameter filename");
332   vul_arg<std::string> test_list_path("-t","List of points files to test on");
333   vul_arg<std::string> output_path("-o","Path for residual statistics output");
334   vul_arg_parse(argc,argv);
335 
336   msm_add_all_loaders();
337 
338   if (param_path().empty()) {
339     print_usage();
340     return 0;
341   }
342 
343   tool_params params;
344   try
345   {
346     params.read_from_file(param_path());
347   }
348   catch (mbl_exception_parse_error& e)
349   {
350     std::cerr<<"Error: "<<e.what()<<'\n';
351     return 1;
352   }
353 
354   image_list_params image_list;
355   if (!test_list_path().empty()) {
356     try
357     {
358       image_list.read_from_file(test_list_path());
359     }
360     catch (mbl_exception_parse_error& e)
361     {
362       std::cerr<<"Error: "<<e.what()<<'\n';
363       return 1;
364     }
365   }
366 
367   msm_shape_model_builder builder;
368 
369   builder.set_aligner(*params.aligner);
370   builder.set_param_limiter(*params.limiter);
371   builder.set_mode_choice(0,params.max_modes,params.var_prop);
372 
373   std::vector<msm_points> shapes(params.points_names.size());
374   msm_load_shapes(params.points_dir,params.points_names,shapes);
375 
376   if (!params.reflection_symmetry.empty()) {
377     // Use reflections
378     msm_points ref_points;
379     unsigned n=shapes.size();
380     for (unsigned i=0;i<n;++i)
381     {
382       msm_reflect_shape_along_x(shapes[i],params.reflection_symmetry,
383                                 ref_points,shapes[i].cog().x());
384       if (params.only_reflect) shapes[i]=ref_points;
385       else                     shapes.push_back(ref_points);
386     }
387   }
388 
389   std::vector<msm_test_stats> test_stats(params.max_modes+1);
390 
391   if (!test_list_path().empty()) {
392     std::cout<<"Testing on "<<image_list.points_names.size()<<" examples from "<<test_list_path()<<std::endl;
393     msm_shape_model shape_model;
394     builder.build_model(shapes,shape_model);
395     std::cout<<"Shape Model: "<<shape_model<<std::endl;
396 
397     std::vector<msm_points> test_shapes(image_list.points_names.size());
398     msm_load_shapes(image_list.points_dir,image_list.points_names,test_shapes);
399     // Test with differing numbers of modes
400     for (unsigned nm=0;nm<=shape_model.n_modes();++nm)
401     {
402       test_model(shape_model,nm,test_shapes,params.ref0,params.ref1,test_stats[nm]);
403     }
404   } else {
405     std::cout<<"Performing "<<params.n_chunks<<"-fold cross validation on "<<shapes.size()<<" examples from training set."<<std::endl;
406     leave_some_out_tests(builder,shapes,params.ref0,params.ref1,params.n_chunks,test_stats);
407   }
408 
409   std::string out_path="residual_stats.txt";
410   if (!output_path().empty())
411     out_path = output_path();
412   std::ofstream ofs(out_path.c_str());
413   if (!ofs)
414   {
415     std::cerr<<"Unable to open "<<out_path<<" for results."<<std::endl;
416     return 2;
417   }
418 
419   ofs<<"NModes WorldMean WorldSD   RefMean ";
420   if (params.ref0!=params.ref1) ofs<<"     RelMean(%)";
421   ofs<<"RefXSD RefYSD ";
422   ofs<<std::endl;
423   for (unsigned nm=0;nm<test_stats.size();++nm)
424   {
425     if (test_stats[nm].world_d_stats.nObs()==0) continue;
426     ofs<<nm<<"      "<<test_stats[nm].world_d_stats.mean();
427     ofs<<" "<<test_stats[nm].world_d_stats.sd();
428     ofs<<"    "<<test_stats[nm].ref_d_stats.mean();
429     if (params.ref0!=params.ref1)
430       ofs<<"    "<<test_stats[nm].rel_d_stats.mean();
431     ofs<<"    "<<test_stats[nm].ref_x_stats.sd();
432     ofs<<"    "<<test_stats[nm].ref_y_stats.sd();
433     ofs<<std::endl;
434   }
435   ofs.close();
436   std::cout<<"Saved residual statistics to "<<out_path<<std::endl;
437 
438   return 0;
439 }
440