1 // Copyright (C) 2011  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 
4 
5 #include <dlib/statistics.h>
6 #include <sstream>
7 #include <string>
8 #include <cstdlib>
9 #include <ctime>
10 #include "tester.h"
11 #include <dlib/pixel.h>
12 #include <dlib/svm_threaded.h>
13 #include <dlib/array.h>
14 #include <dlib/set_utils.h>
15 #include <dlib/array2d.h>
16 #include <dlib/image_keypoint.h>
17 #include <dlib/image_processing.h>
18 #include <dlib/image_transforms.h>
19 
20 namespace
21 {
    // Bring the test harness, dlib, and std names used throughout this file into scope.
    using namespace test;
    using namespace dlib;
    using namespace std;

    // Logger named "test.object_detector"; every test below writes its progress here.
    logger dlog("test.object_detector");
27 
28 // ----------------------------------------------------------------------------------------
29 
30     struct funny_image
31     {
32         array2d<unsigned char> img;
nr__anon62e9091a0111::funny_image33         long nr() const { return img.nr(); }
nc__anon62e9091a0111::funny_image34         long nc() const { return img.nc(); }
35     };
36 
swap(funny_image & a,funny_image & b)37     void swap(funny_image& a, funny_image& b)
38     {
39         a.img.swap(b.img);
40     }
41 
42 // ----------------------------------------------------------------------------------------
43 
44     template <
45         typename image_array_type,
46         typename detector_type
47         >
validate_some_object_detector_stuff(const image_array_type & images,detector_type & detector,double eps=1e-10)48     void validate_some_object_detector_stuff (
49         const image_array_type& images,
50         detector_type& detector,
51         double eps = 1e-10
52     )
53     {
54         for (unsigned long i = 0; i < images.size(); ++i)
55         {
56             std::vector<rectangle> dets = detector(images[i]);
57             std::vector<std::pair<double,rectangle> > dets2;
58 
59             detector(images[i], dets2);
60 
61             matrix<double,0,1> psi(detector.get_w().size());
62             matrix<double,0,1> psi2(detector.get_w().size());
63             const double thresh = detector.get_w()(detector.get_w().size()-1);
64 
65             DLIB_TEST(dets.size() == dets2.size());
66             for (unsigned long j = 0; j < dets.size(); ++j)
67             {
68                 DLIB_TEST(dets[j] == dets2[j].second);
69 
70                 const full_object_detection fdet = detector.get_scanner().get_full_object_detection(dets[j], detector.get_w());
71                 psi = 0;
72                 detector.get_scanner().get_feature_vector(fdet, psi);
73 
74                 double check_score = dot(psi,detector.get_w()) - thresh;
75                 DLIB_TEST_MSG(std::abs(check_score - dets2[j].first) < eps, std::abs(check_score - dets2[j].first) << "  check_score: "<< check_score);
76             }
77 
78         }
79     }
80 
81 // ----------------------------------------------------------------------------------------
82 
83     class very_simple_feature_extractor : noncopyable
84     {
85         /*!
86         WHAT THIS OBJECT REPRESENTS
87             This object is a feature extractor which goes to every pixel in an image and
88             produces a 32 dimensional feature vector.  This vector is an indicator vector
89             which records the pattern of pixel values in a 4-connected region.  So it should
90             be able to distinguish basic things like whether or not a location falls on the
91             corner of a white box, on an edge, in the middle, etc.
92 
93 
94             Note that this object also implements the interface defined in dlib/image_keypoint/hashed_feature_image_abstract.h.
95             This means all the member functions in this object are supposed to behave as
96             described in the hashed_feature_image specification.  So when you define your own
97             feature extractor objects you should probably refer yourself to that documentation
98             in addition to reading this example program.
99         !*/
100 
101 
102     public:
103 
load(const funny_image & img_)104         inline void load (
105             const funny_image& img_
106         )
107         {
108             const array2d<unsigned char>& img = img_.img;
109 
110             feat_image.set_size(img.nr(), img.nc());
111             assign_all_pixels(feat_image,0);
112             for (long r = 1; r+1 < img.nr(); ++r)
113             {
114                 for (long c = 1; c+1 < img.nc(); ++c)
115                 {
116                     unsigned char f = 0;
117                     if (img[r][c])   f |= 0x1;
118                     if (img[r][c+1]) f |= 0x2;
119                     if (img[r][c-1]) f |= 0x4;
120                     if (img[r+1][c]) f |= 0x8;
121                     if (img[r-1][c]) f |= 0x10;
122 
123                     // Store the code value for the pattern of pixel values in the 4-connected
124                     // neighborhood around this row and column.
125                     feat_image[r][c] = f;
126                 }
127             }
128         }
129 
load(const array2d<unsigned char> & img)130         inline void load (
131             const array2d<unsigned char>& img
132         )
133         {
134             feat_image.set_size(img.nr(), img.nc());
135             assign_all_pixels(feat_image,0);
136             for (long r = 1; r+1 < img.nr(); ++r)
137             {
138                 for (long c = 1; c+1 < img.nc(); ++c)
139                 {
140                     unsigned char f = 0;
141                     if (img[r][c])   f |= 0x1;
142                     if (img[r][c+1]) f |= 0x2;
143                     if (img[r][c-1]) f |= 0x4;
144                     if (img[r+1][c]) f |= 0x8;
145                     if (img[r-1][c]) f |= 0x10;
146 
147                     // Store the code value for the pattern of pixel values in the 4-connected
148                     // neighborhood around this row and column.
149                     feat_image[r][c] = f;
150                 }
151             }
152         }
153 
size() const154         inline size_t size () const { return feat_image.size(); }
nr() const155         inline long nr () const { return feat_image.nr(); }
nc() const156         inline long nc () const { return feat_image.nc(); }
157 
get_num_dimensions() const158         inline long get_num_dimensions (
159         ) const
160         {
161             // Return the dimensionality of the vectors produced by operator()
162             return 32;
163         }
164 
165         typedef std::vector<std::pair<unsigned int,double> > descriptor_type;
166 
operator ()(long row,long col) const167         inline const descriptor_type& operator() (
168             long row,
169             long col
170         ) const
171             /*!
172                 requires
173                     - 0 <= row < nr()
174             - 0 <= col < nc()
175                 ensures
176                     - returns a sparse vector which describes the image at the given row and column.
177                       In particular, this is a vector that is 0 everywhere except for one element.
178             !*/
179         {
180             feat.clear();
181             const unsigned long only_nonzero_element_index = feat_image[row][col];
182             feat.push_back(make_pair(only_nonzero_element_index,1.0));
183             return feat;
184         }
185 
186         // This block of functions is meant to provide a way to map between the row/col space taken by
187         // this object's operator() function and the images supplied to load().  In this example it's trivial.
188         // However, in general, you might create feature extractors which don't perform extraction at every
189         // possible image location (e.g. the hog_image) and thus result in some more complex mapping.
get_block_rect(long row,long col) const190         inline const rectangle get_block_rect       ( long row, long col) const { return centered_rect(col,row,3,3); }
image_to_feat_space(const point & p) const191         inline const point image_to_feat_space      ( const point& p) const { return p; }
image_to_feat_space(const rectangle & rect) const192         inline const rectangle image_to_feat_space  ( const rectangle& rect) const { return rect; }
feat_to_image_space(const point & p) const193         inline const point feat_to_image_space      ( const point& p) const { return p; }
feat_to_image_space(const rectangle & rect) const194         inline const rectangle feat_to_image_space  ( const rectangle& rect) const { return rect; }
195 
serialize(const very_simple_feature_extractor & item,std::ostream & out)196         inline friend void serialize   ( const very_simple_feature_extractor& item, std::ostream& out)  { serialize(item.feat_image, out); }
deserialize(very_simple_feature_extractor & item,std::istream & in)197         inline friend void deserialize ( very_simple_feature_extractor& item, std::istream& in ) { deserialize(item.feat_image, in); }
198 
copy_configuration(const very_simple_feature_extractor &)199         void copy_configuration ( const very_simple_feature_extractor& ){}
200 
201     private:
202         array2d<unsigned char> feat_image;
203 
204         // This variable doesn't logically contribute to the state of this object.  It is here
205         // only to avoid returning a descriptor_type object by value inside the operator() method.
206         mutable descriptor_type feat;
207     };
208 
209 // ----------------------------------------------------------------------------------------
210 
211     template <
212         typename image_array_type
213         >
make_simple_test_data(image_array_type & images,std::vector<std::vector<rectangle>> & object_locations)214     void make_simple_test_data (
215         image_array_type& images,
216         std::vector<std::vector<rectangle> >& object_locations
217     )
218     {
219         images.clear();
220         object_locations.clear();
221 
222         images.resize(3);
223         images[0].set_size(400,400);
224         images[1].set_size(400,400);
225         images[2].set_size(400,400);
226 
227         // set all the pixel values to black
228         assign_all_pixels(images[0], 0);
229         assign_all_pixels(images[1], 0);
230         assign_all_pixels(images[2], 0);
231 
232         // Now make some squares and draw them onto our black images. All the
233         // squares will be 70 pixels wide and tall.
234 
235         std::vector<rectangle> temp;
236         temp.push_back(centered_rect(point(100,100), 70,70));
237         fill_rect(images[0],temp.back(),255); // Paint the square white
238         temp.push_back(centered_rect(point(200,300), 70,70));
239         fill_rect(images[0],temp.back(),255); // Paint the square white
240         object_locations.push_back(temp);
241 
242         temp.clear();
243         temp.push_back(centered_rect(point(140,200), 70,70));
244         fill_rect(images[1],temp.back(),255); // Paint the square white
245         temp.push_back(centered_rect(point(303,200), 70,70));
246         fill_rect(images[1],temp.back(),255); // Paint the square white
247         object_locations.push_back(temp);
248 
249         temp.clear();
250         temp.push_back(centered_rect(point(123,121), 70,70));
251         fill_rect(images[2],temp.back(),255); // Paint the square white
252         object_locations.push_back(temp);
253 
254         // corrupt each image with random noise just to make this a little more
255         // challenging
256         dlib::rand rnd;
257         for (unsigned long i = 0; i < images.size(); ++i)
258         {
259             for (long r = 0; r < images[i].nr(); ++r)
260             {
261                 for (long c = 0; c < images[i].nc(); ++c)
262                 {
263                     typedef typename image_array_type::type image_type;
264                     typedef typename image_type::type type;
265                     images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 10*rnd.get_random_gaussian());
266                 }
267             }
268         }
269     }
270 
271     template <
272         typename image_array_type
273         >
make_simple_test_data(image_array_type & images,std::vector<std::vector<full_object_detection>> & object_locations)274     void make_simple_test_data (
275         image_array_type& images,
276         std::vector<std::vector<full_object_detection> >& object_locations
277     )
278     {
279         images.clear();
280         object_locations.clear();
281 
282 
283         images.resize(3);
284         images[0].set_size(400,400);
285         images[1].set_size(400,400);
286         images[2].set_size(400,400);
287 
288         // set all the pixel values to black
289         assign_all_pixels(images[0], 0);
290         assign_all_pixels(images[1], 0);
291         assign_all_pixels(images[2], 0);
292 
293         // Now make some squares and draw them onto our black images. All the
294         // squares will be 70 pixels wide and tall.
295         const int shrink = 0;
296         std::vector<full_object_detection> temp;
297 
298         rectangle rect = centered_rect(point(100,100), 70,71);
299         std::vector<point> movable_parts;
300         movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
301         movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
302         movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
303         movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
304         temp.push_back(full_object_detection(rect, movable_parts));
305         fill_rect(images[0],rect,255); // Paint the square white
306 
307         rect = centered_rect(point(200,200), 70,71);
308         movable_parts.clear();
309         movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
310         movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
311         movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
312         movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
313         temp.push_back(full_object_detection(rect, movable_parts));
314         fill_rect(images[0],rect,255); // Paint the square white
315 
316         object_locations.push_back(temp);
317         // ------------------------------------
318         temp.clear();
319 
320         rect = centered_rect(point(140,200), 70,71);
321         movable_parts.clear();
322         movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
323         movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
324         movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
325         movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
326         temp.push_back(full_object_detection(rect, movable_parts));
327         fill_rect(images[1],rect,255); // Paint the square white
328 
329 
330         rect = centered_rect(point(303,200), 70,71);
331         movable_parts.clear();
332         movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
333         movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
334         movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
335         movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
336         temp.push_back(full_object_detection(rect, movable_parts));
337         fill_rect(images[1],rect,255); // Paint the square white
338 
339         object_locations.push_back(temp);
340         // ------------------------------------
341         temp.clear();
342 
343         rect = centered_rect(point(123,121), 70,71);
344         movable_parts.clear();
345         movable_parts.push_back(shrink_rect(rect,shrink).tl_corner());
346         movable_parts.push_back(shrink_rect(rect,shrink).tr_corner());
347         movable_parts.push_back(shrink_rect(rect,shrink).bl_corner());
348         movable_parts.push_back(shrink_rect(rect,shrink).br_corner());
349         temp.push_back(full_object_detection(rect, movable_parts));
350         fill_rect(images[2],rect,255); // Paint the square white
351 
352         object_locations.push_back(temp);
353 
354         // corrupt each image with random noise just to make this a little more
355         // challenging
356         dlib::rand rnd;
357         for (unsigned long i = 0; i < images.size(); ++i)
358         {
359             for (long r = 0; r < images[i].nr(); ++r)
360             {
361                 for (long c = 0; c < images[i].nc(); ++c)
362                 {
363                     typedef typename image_array_type::type image_type;
364                     typedef typename image_type::type type;
365                     images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 40*rnd.get_random_gaussian());
366                 }
367             }
368         }
369     }
370 
371 // ----------------------------------------------------------------------------------------
372 
test_fhog_pyramid()373     void test_fhog_pyramid (
374     )
375     {
376         print_spinner();
377         dlog << LINFO << "test_fhog_pyramid()";
378 
379         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
380         grayscale_image_array_type images;
381         std::vector<std::vector<rectangle> > object_locations;
382         make_simple_test_data(images, object_locations);
383 
384         typedef scan_fhog_pyramid<pyramid_down<2> > image_scanner_type;
385         image_scanner_type scanner;
386         scanner.set_detection_window_size(35,35);
387         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
388         trainer.set_num_threads(4);
389         trainer.set_overlap_tester(test_box_overlap(0,0));
390         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
391 
392         matrix<double> res = test_object_detection_function(detector, images, object_locations);
393         dlog << LINFO << "Test detector (precision,recall): " << res;
394         DLIB_TEST(sum(res) == 3);
395 
396         {
397             ostringstream sout;
398             serialize(detector, sout);
399             istringstream sin(sout.str());
400             object_detector<image_scanner_type> d2;
401             deserialize(d2, sin);
402             matrix<double> res = test_object_detection_function(d2, images, object_locations);
403             dlog << LINFO << "Test detector (precision,recall): " << res;
404             DLIB_TEST(sum(res) == 3);
405 
406             validate_some_object_detector_stuff(images, detector, 1e-6);
407         }
408 
409         {
410             std::vector<object_detector<image_scanner_type> > detectors;
411             detectors.push_back(detector);
412             detectors.push_back(detector);
413             detectors.push_back(detector);
414 
415             std::vector<rectangle> dets1 = evaluate_detectors(detectors, images[0]);
416             std::vector<rectangle> dets2 = detector(images[0]);
417             DLIB_TEST(dets1.size() > 0);
418             DLIB_TEST(dets2.size()*3 == dets1.size());
419             dlib::set<rectangle>::kernel_1a_c d1, d2;
420             for (unsigned long i = 0; i < dets1.size(); ++i)
421             {
422                 if (!d1.is_member(dets1[i]))
423                     d1.add(dets1[i]);
424             }
425             for (unsigned long i = 0; i < dets2.size(); ++i)
426             {
427                 if (!d2.is_member(dets2[i]))
428                     d2.add(dets2[i]);
429             }
430             DLIB_TEST(d1.size() == d2.size());
431             DLIB_TEST(set_intersection_size(d1,d2) == d1.size());
432         }
433     }
434 
435 // ----------------------------------------------------------------------------------------
436 
test_1()437     void test_1 (
438     )
439     {
440         print_spinner();
441         dlog << LINFO << "test_1()";
442 
443         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
444         grayscale_image_array_type images;
445         std::vector<std::vector<rectangle> > object_locations;
446         make_simple_test_data(images, object_locations);
447 
448         typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
449         typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
450         image_scanner_type scanner;
451         const rectangle object_box = compute_box_dimensions(1,35*35);
452         scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
453         setup_hashed_features(scanner, images, 9);
454         use_uniform_feature_weights(scanner);
455         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
456         trainer.set_num_threads(4);
457         trainer.set_overlap_tester(test_box_overlap(0,0));
458         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
459 
460         matrix<double> res = test_object_detection_function(detector, images, object_locations);
461         dlog << LINFO << "Test detector (precision,recall): " << res;
462         DLIB_TEST(sum(res) == 3);
463 
464         {
465             ostringstream sout;
466             serialize(detector, sout);
467             istringstream sin(sout.str());
468             object_detector<image_scanner_type> d2;
469             deserialize(d2, sin);
470             matrix<double> res = test_object_detection_function(d2, images, object_locations);
471             dlog << LINFO << "Test detector (precision,recall): " << res;
472             DLIB_TEST(sum(res) == 3);
473 
474             validate_some_object_detector_stuff(images, detector);
475         }
476     }
477 
478 // ----------------------------------------------------------------------------------------
479 
test_1_boxes()480     void test_1_boxes (
481     )
482     {
483         print_spinner();
484         dlog << LINFO << "test_1_boxes()";
485 
486         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
487         grayscale_image_array_type images;
488         std::vector<std::vector<rectangle> > object_locations;
489         make_simple_test_data(images, object_locations);
490 
491         typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
492         typedef scan_image_boxes<feature_extractor_type> image_scanner_type;
493         image_scanner_type scanner;
494         setup_hashed_features(scanner, images, 9);
495         use_uniform_feature_weights(scanner);
496         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
497         trainer.set_num_threads(4);
498         trainer.set_overlap_tester(test_box_overlap(0,0));
499         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
500 
501         matrix<double> res = test_object_detection_function(detector, images, object_locations);
502         dlog << LINFO << "Test detector (precision,recall): " << res;
503         DLIB_TEST(sum(res) == 3);
504 
505         {
506             ostringstream sout;
507             serialize(detector, sout);
508             istringstream sin(sout.str());
509             object_detector<image_scanner_type> d2;
510             deserialize(d2, sin);
511             matrix<double> res = test_object_detection_function(d2, images, object_locations);
512             dlog << LINFO << "Test detector (precision,recall): " << res;
513             DLIB_TEST(sum(res) == 3);
514 
515             validate_some_object_detector_stuff(images, detector);
516         }
517     }
518 
519 // ----------------------------------------------------------------------------------------
520 
test_1m()521     void test_1m (
522     )
523     {
524         print_spinner();
525         dlog << LINFO << "test_1m()";
526 
527         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
528         grayscale_image_array_type images;
529         std::vector<std::vector<full_object_detection> > object_locations;
530         make_simple_test_data(images, object_locations);
531 
532         typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type;
533         typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
534         image_scanner_type scanner;
535         const rectangle object_box = compute_box_dimensions(1,35*35);
536         std::vector<rectangle> mboxes;
537         const int mbox_size = 20;
538         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
539         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
540         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
541         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
542         scanner.add_detection_template(object_box, create_grid_detection_template(object_box,1,1), mboxes);
543         setup_hashed_features(scanner, images, 9);
544         use_uniform_feature_weights(scanner);
545         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
546         trainer.set_num_threads(4);
547         trainer.set_overlap_tester(test_box_overlap(0,0));
548         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
549 
550         matrix<double> res = test_object_detection_function(detector, images, object_locations);
551         dlog << LINFO << "Test detector (precision,recall): " << res;
552         DLIB_TEST(sum(res) == 3);
553 
554         {
555             ostringstream sout;
556             serialize(detector, sout);
557             istringstream sin(sout.str());
558             object_detector<image_scanner_type> d2;
559             deserialize(d2, sin);
560             matrix<double> res = test_object_detection_function(d2, images, object_locations);
561             dlog << LINFO << "Test detector (precision,recall): " << res;
562             DLIB_TEST(sum(res) == 3);
563 
564             validate_some_object_detector_stuff(images, detector);
565         }
566     }
567 
568 // ----------------------------------------------------------------------------------------
569 
test_1_fine_hog()570     void test_1_fine_hog (
571     )
572     {
573         print_spinner();
574         dlog << LINFO << "test_1_fine_hog()";
575 
576         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
577         grayscale_image_array_type images;
578         std::vector<std::vector<rectangle> > object_locations;
579         make_simple_test_data(images, object_locations);
580 
581         typedef hashed_feature_image<fine_hog_image<3,3,2,4,hog_signed_gradient> > feature_extractor_type;
582         typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
583         image_scanner_type scanner;
584         const rectangle object_box = compute_box_dimensions(1,35*35);
585         scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
586         setup_hashed_features(scanner, images, 9);
587         use_uniform_feature_weights(scanner);
588         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
589         trainer.set_num_threads(4);
590         trainer.set_overlap_tester(test_box_overlap(0,0));
591         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
592 
593         matrix<double> res = test_object_detection_function(detector, images, object_locations);
594         dlog << LINFO << "Test detector (precision,recall): " << res;
595         DLIB_TEST(sum(res) == 3);
596 
597         {
598             ostringstream sout;
599             serialize(detector, sout);
600             istringstream sin(sout.str());
601             object_detector<image_scanner_type> d2;
602             deserialize(d2, sin);
603             matrix<double> res = test_object_detection_function(d2, images, object_locations);
604             dlog << LINFO << "Test detector (precision,recall): " << res;
605             DLIB_TEST(sum(res) == 3);
606 
607             validate_some_object_detector_stuff(images, detector);
608         }
609     }
610 
611 // ----------------------------------------------------------------------------------------
612 
test_1_poly()613     void test_1_poly (
614     )
615     {
616         print_spinner();
617         dlog << LINFO << "test_1_poly()";
618 
619         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
620         grayscale_image_array_type images;
621         std::vector<std::vector<rectangle> > object_locations;
622         make_simple_test_data(images, object_locations);
623 
624         typedef hashed_feature_image<poly_image<2> > feature_extractor_type;
625         typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
626         image_scanner_type scanner;
627         const rectangle object_box = compute_box_dimensions(1,35*35);
628         scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
629         setup_hashed_features(scanner, images, 9);
630         use_uniform_feature_weights(scanner);
631         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
632         trainer.set_num_threads(4);
633         trainer.set_overlap_tester(test_box_overlap(0,0));
634         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
635 
636         matrix<double> res = test_object_detection_function(detector, images, object_locations);
637         dlog << LINFO << "Test detector (precision,recall): " << res;
638         DLIB_TEST(sum(res) == 3);
639 
640         {
641             ostringstream sout;
642             serialize(detector, sout);
643             istringstream sin(sout.str());
644             object_detector<image_scanner_type> d2;
645             deserialize(d2, sin);
646             matrix<double> res = test_object_detection_function(d2, images, object_locations);
647             dlog << LINFO << "Test detector (precision,recall): " << res;
648             DLIB_TEST(sum(res) == 3);
649 
650             validate_some_object_detector_stuff(images, detector);
651         }
652     }
653 
654 // ----------------------------------------------------------------------------------------
655 
test_1m_poly()656     void test_1m_poly (
657     )
658     {
659         print_spinner();
660         dlog << LINFO << "test_1_poly()";
661 
662         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
663         grayscale_image_array_type images;
664         std::vector<std::vector<full_object_detection> > object_locations;
665         make_simple_test_data(images, object_locations);
666 
667         typedef hashed_feature_image<poly_image<2> > feature_extractor_type;
668         typedef scan_image_pyramid<pyramid_down<3>, feature_extractor_type> image_scanner_type;
669         image_scanner_type scanner;
670         const rectangle object_box = compute_box_dimensions(1,35*35);
671         std::vector<rectangle> mboxes;
672         const int mbox_size = 20;
673         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
674         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
675         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
676         mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size));
677         scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2), mboxes);
678         setup_hashed_features(scanner, images, 9);
679         use_uniform_feature_weights(scanner);
680         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
681         trainer.set_num_threads(4);
682         trainer.set_overlap_tester(test_box_overlap(0,0));
683         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
684 
685         matrix<double> res = test_object_detection_function(detector, images, object_locations);
686         dlog << LINFO << "Test detector (precision,recall): " << res;
687         DLIB_TEST(sum(res) == 3);
688 
689         {
690             ostringstream sout;
691             serialize(detector, sout);
692             istringstream sin(sout.str());
693             object_detector<image_scanner_type> d2;
694             deserialize(d2, sin);
695             matrix<double> res = test_object_detection_function(d2, images, object_locations);
696             dlog << LINFO << "Test detector (precision,recall): " << res;
697             DLIB_TEST(sum(res) == 3);
698 
699             validate_some_object_detector_stuff(images, detector);
700         }
701     }
702 
703 // ----------------------------------------------------------------------------------------
704 
test_1_poly_nn()705     void test_1_poly_nn (
706     )
707     {
708         print_spinner();
709         dlog << LINFO << "test_1_poly_nn()";
710 
711         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
712         grayscale_image_array_type images;
713         std::vector<std::vector<rectangle> > object_locations;
714         make_simple_test_data(images, object_locations);
715 
716         typedef nearest_neighbor_feature_image<poly_image<5> > feature_extractor_type;
717         typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type;
718         image_scanner_type scanner;
719 
720         setup_grid_detection_templates(scanner, object_locations, 2, 2);
721         feature_extractor_type nnfe;
722         pyramid_down<2> pyr_down;
723         poly_image<5> polyi;
724         nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80));
725         scanner.copy_configuration(nnfe);
726 
727         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
728         trainer.set_num_threads(4);
729         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
730 
731         matrix<double> res = test_object_detection_function(detector, images, object_locations);
732         dlog << LINFO << "Test detector (precision,recall): " << res;
733         DLIB_TEST(sum(res) == 3);
734 
735         {
736             ostringstream sout;
737             serialize(detector, sout);
738             istringstream sin(sout.str());
739             object_detector<image_scanner_type> d2;
740             deserialize(d2, sin);
741             matrix<double> res = test_object_detection_function(d2, images, object_locations);
742             dlog << LINFO << "Test detector (precision,recall): " << res;
743             DLIB_TEST(sum(res) == 3);
744 
745             validate_some_object_detector_stuff(images, detector);
746         }
747     }
748 
749 // ----------------------------------------------------------------------------------------
750 
test_1_poly_nn_boxes()751     void test_1_poly_nn_boxes (
752     )
753     {
754         print_spinner();
755         dlog << LINFO << "test_1_poly_nn_boxes()";
756 
757         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
758         grayscale_image_array_type images;
759         std::vector<std::vector<rectangle> > object_locations;
760         make_simple_test_data(images, object_locations);
761 
762         typedef nearest_neighbor_feature_image<poly_image<5> > feature_extractor_type;
763         typedef scan_image_boxes<feature_extractor_type> image_scanner_type;
764         image_scanner_type scanner;
765 
766         feature_extractor_type nnfe;
767         pyramid_down<2> pyr_down;
768         poly_image<5> polyi;
769         nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80));
770         scanner.copy_configuration(nnfe);
771 
772         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
773         trainer.set_num_threads(4);
774         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
775 
776         matrix<double> res = test_object_detection_function(detector, images, object_locations);
777         dlog << LINFO << "Test detector (precision,recall): " << res;
778         DLIB_TEST(sum(res) == 3);
779 
780         {
781             ostringstream sout;
782             serialize(detector, sout);
783             istringstream sin(sout.str());
784             object_detector<image_scanner_type> d2;
785             deserialize(d2, sin);
786             matrix<double> res = test_object_detection_function(d2, images, object_locations);
787             dlog << LINFO << "Test detector (precision,recall): " << res;
788             DLIB_TEST(sum(res) == 3);
789 
790             validate_some_object_detector_stuff(images, detector);
791         }
792     }
793 
794 // ----------------------------------------------------------------------------------------
795 
test_2()796     void test_2 (
797     )
798     {
799         print_spinner();
800         dlog << LINFO << "test_2()";
801 
802         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
803         grayscale_image_array_type images;
804         std::vector<std::vector<rectangle> > object_locations;
805         make_simple_test_data(images, object_locations);
806 
807         typedef scan_image_pyramid<pyramid_down<5>, very_simple_feature_extractor> image_scanner_type;
808         image_scanner_type scanner;
809         const rectangle object_box = compute_box_dimensions(1,70*70);
810         scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
811         scanner.set_max_pyramid_levels(1);
812         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
813         trainer.set_num_threads(0);
814         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
815 
816         matrix<double> res = test_object_detection_function(detector, images, object_locations);
817         dlog << LINFO << "Test detector (precision,recall): " << res;
818         DLIB_TEST(sum(res) == 3);
819 
820         res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
821         dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
822         DLIB_TEST(sum(res) == 3);
823 
824         {
825             ostringstream sout;
826             serialize(detector, sout);
827             istringstream sin(sout.str());
828             object_detector<image_scanner_type> d2;
829             deserialize(d2, sin);
830             matrix<double> res = test_object_detection_function(d2, images, object_locations);
831             dlog << LINFO << "Test detector (precision,recall): " << res;
832             DLIB_TEST(sum(res) == 3);
833             validate_some_object_detector_stuff(images, detector);
834         }
835     }
836 
837 // ----------------------------------------------------------------------------------------
838 
839     class pyramid_down_funny : noncopyable
840     {
841         pyramid_down<2> pyr;
842     public:
843 
844         template <typename T>
point_down(const dlib::vector<T,2> & p) const845         dlib::vector<double,2> point_down ( const dlib::vector<T,2>& p) const { return pyr.point_down(p); }
846 
847         template <typename T>
point_up(const dlib::vector<T,2> & p) const848         dlib::vector<double,2> point_up ( const dlib::vector<T,2>& p) const { return pyr.point_up(p); }
849 
850         template <typename T>
point_down(const dlib::vector<T,2> & p,unsigned int levels) const851         dlib::vector<double,2> point_down ( const dlib::vector<T,2>& p, unsigned int levels) const { return pyr.point_down(p,levels); }
852 
853         template <typename T>
point_up(const dlib::vector<T,2> & p,unsigned int levels) const854         dlib::vector<double,2> point_up ( const dlib::vector<T,2>& p, unsigned int levels) const { return pyr.point_up(p,levels); }
855 
rect_up(const rectangle & rect) const856         rectangle rect_up ( const rectangle& rect) const { return pyr.rect_up(rect); }
857 
rect_up(const rectangle & rect,unsigned int levels) const858         rectangle rect_up ( const rectangle& rect, unsigned int levels) const { return pyr.rect_up(rect,levels); }
859 
rect_down(const rectangle & rect) const860         rectangle rect_down ( const rectangle& rect) const { return pyr.rect_down(rect); }
861 
rect_down(const rectangle & rect,unsigned int levels) const862         rectangle rect_down ( const rectangle& rect, unsigned int levels) const { return pyr.rect_down(rect,levels); }
863 
864         template <
865             typename in_image_type,
866             typename out_image_type
867             >
operator ()(const in_image_type & original,out_image_type & down) const868         void operator() (
869             const in_image_type& original,
870             out_image_type& down
871         ) const
872         {
873             pyr(original.img, down.img);
874         }
875 
876     };
877 
878     // make sure everything works even when the image isn't a dlib::array2d.
879     // So test with funny_image.
test_3()880     void test_3 (
881     )
882     {
883         print_spinner();
884         dlog << LINFO << "test_3()";
885 
886 
887         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
888         typedef dlib::array<funny_image>  funny_image_array_type;
889         grayscale_image_array_type images_temp;
890         funny_image_array_type images;
891         std::vector<std::vector<rectangle> > object_locations;
892         make_simple_test_data(images_temp, object_locations);
893         images.resize(images_temp.size());
894         for (unsigned long i = 0; i < images_temp.size(); ++i)
895         {
896             images[i].img.swap(images_temp[i]);
897         }
898 
899         typedef scan_image_pyramid<pyramid_down_funny, very_simple_feature_extractor> image_scanner_type;
900         image_scanner_type scanner;
901         const rectangle object_box = compute_box_dimensions(1,70*70);
902         scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
903         scanner.set_max_pyramid_levels(1);
904         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
905         trainer.set_num_threads(4);
906         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
907 
908         matrix<double> res = test_object_detection_function(detector, images, object_locations);
909         dlog << LINFO << "Test detector (precision,recall): " << res;
910         DLIB_TEST(sum(res) == 3);
911 
912         res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
913         dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
914         DLIB_TEST(sum(res) == 3);
915 
916         {
917             ostringstream sout;
918             serialize(detector, sout);
919             istringstream sin(sout.str());
920             object_detector<image_scanner_type> d2;
921             deserialize(d2, sin);
922             matrix<double> res = test_object_detection_function(d2, images, object_locations);
923             dlog << LINFO << "Test detector (precision,recall): " << res;
924             DLIB_TEST(sum(res) == 3);
925         }
926     }
927 
928 // ----------------------------------------------------------------------------------------
929 
930     class funny_box_generator
931     {
932     public:
933         template <typename image_type>
operator ()(const image_type & img,std::vector<rectangle> & rects) const934         void operator() (
935             const image_type& img,
936             std::vector<rectangle>& rects
937         ) const
938         {
939             rects.clear();
940             find_candidate_object_locations(img.img, rects);
941             dlog << LINFO << "funny_box_generator, rects.size(): "<< rects.size();
942         }
943     };
944 
serialize(const funny_box_generator &,std::ostream &)945     inline void serialize(const funny_box_generator&, std::ostream& ) {}
deserialize(funny_box_generator &,std::istream &)946     inline void deserialize(funny_box_generator&, std::istream& ) {}
947 
948 
949     // make sure everything works even when the image isn't a dlib::array2d.
950     // So test with funny_image.
test_3_boxes()951     void test_3_boxes (
952     )
953     {
954         print_spinner();
955         dlog << LINFO << "test_3_boxes()";
956 
957 
958         typedef dlib::array<array2d<unsigned char> >  grayscale_image_array_type;
959         typedef dlib::array<funny_image>  funny_image_array_type;
960         grayscale_image_array_type images_temp;
961         funny_image_array_type images;
962         std::vector<std::vector<rectangle> > object_locations;
963         make_simple_test_data(images_temp, object_locations);
964         images.resize(images_temp.size());
965         for (unsigned long i = 0; i < images_temp.size(); ++i)
966         {
967             images[i].img.swap(images_temp[i]);
968         }
969 
970         typedef scan_image_boxes<very_simple_feature_extractor, funny_box_generator> image_scanner_type;
971         image_scanner_type scanner;
972         structural_object_detection_trainer<image_scanner_type> trainer(scanner);
973         trainer.set_num_threads(4);
974         object_detector<image_scanner_type> detector = trainer.train(images, object_locations);
975 
976         matrix<double> res = test_object_detection_function(detector, images, object_locations);
977         dlog << LINFO << "Test detector (precision,recall): " << res;
978         DLIB_TEST(sum(res) == 3);
979 
980         res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
981         dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
982         DLIB_TEST(sum(res) == 3);
983 
984         {
985             ostringstream sout;
986             serialize(detector, sout);
987             istringstream sin(sout.str());
988             object_detector<image_scanner_type> d2;
989             deserialize(d2, sin);
990             matrix<double> res = test_object_detection_function(d2, images, object_locations);
991             dlog << LINFO << "Test detector (precision,recall): " << res;
992             DLIB_TEST(sum(res) == 3);
993         }
994     }
995 
996 // ----------------------------------------------------------------------------------------
997 
998     class object_detector_tester : public tester
999     {
1000     public:
object_detector_tester()1001         object_detector_tester (
1002         ) :
1003             tester ("test_object_detector",
1004                     "Runs tests on the structural object detection stuff.")
1005         {}
1006 
perform_test()1007         void perform_test (
1008         )
1009         {
1010             test_fhog_pyramid();
1011             test_1_boxes();
1012             test_1_poly_nn_boxes();
1013             test_3_boxes();
1014 
1015             test_1();
1016             test_1m();
1017             test_1_fine_hog();
1018             test_1_poly();
1019             test_1m_poly();
1020             test_1_poly_nn();
1021             test_2();
1022             test_3();
1023         }
1024     } a;
1025 
1026 }
1027 
1028 
1029