1 // Copyright (C) 2011 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 4 5 #include <dlib/statistics.h> 6 #include <sstream> 7 #include <string> 8 #include <cstdlib> 9 #include <ctime> 10 #include "tester.h" 11 #include <dlib/pixel.h> 12 #include <dlib/svm_threaded.h> 13 #include <dlib/array.h> 14 #include <dlib/set_utils.h> 15 #include <dlib/array2d.h> 16 #include <dlib/image_keypoint.h> 17 #include <dlib/image_processing.h> 18 #include <dlib/image_transforms.h> 19 20 namespace 21 { 22 using namespace test; 23 using namespace dlib; 24 using namespace std; 25 26 logger dlog("test.object_detector"); 27 28 // ---------------------------------------------------------------------------------------- 29 30 struct funny_image 31 { 32 array2d<unsigned char> img; nr__anon62e9091a0111::funny_image33 long nr() const { return img.nr(); } nc__anon62e9091a0111::funny_image34 long nc() const { return img.nc(); } 35 }; 36 swap(funny_image & a,funny_image & b)37 void swap(funny_image& a, funny_image& b) 38 { 39 a.img.swap(b.img); 40 } 41 42 // ---------------------------------------------------------------------------------------- 43 44 template < 45 typename image_array_type, 46 typename detector_type 47 > validate_some_object_detector_stuff(const image_array_type & images,detector_type & detector,double eps=1e-10)48 void validate_some_object_detector_stuff ( 49 const image_array_type& images, 50 detector_type& detector, 51 double eps = 1e-10 52 ) 53 { 54 for (unsigned long i = 0; i < images.size(); ++i) 55 { 56 std::vector<rectangle> dets = detector(images[i]); 57 std::vector<std::pair<double,rectangle> > dets2; 58 59 detector(images[i], dets2); 60 61 matrix<double,0,1> psi(detector.get_w().size()); 62 matrix<double,0,1> psi2(detector.get_w().size()); 63 const double thresh = detector.get_w()(detector.get_w().size()-1); 64 65 DLIB_TEST(dets.size() == dets2.size()); 66 for (unsigned long j = 0; j < dets.size(); ++j) 67 { 
68 DLIB_TEST(dets[j] == dets2[j].second); 69 70 const full_object_detection fdet = detector.get_scanner().get_full_object_detection(dets[j], detector.get_w()); 71 psi = 0; 72 detector.get_scanner().get_feature_vector(fdet, psi); 73 74 double check_score = dot(psi,detector.get_w()) - thresh; 75 DLIB_TEST_MSG(std::abs(check_score - dets2[j].first) < eps, std::abs(check_score - dets2[j].first) << " check_score: "<< check_score); 76 } 77 78 } 79 } 80 81 // ---------------------------------------------------------------------------------------- 82 83 class very_simple_feature_extractor : noncopyable 84 { 85 /*! 86 WHAT THIS OBJECT REPRESENTS 87 This object is a feature extractor which goes to every pixel in an image and 88 produces a 32 dimensional feature vector. This vector is an indicator vector 89 which records the pattern of pixel values in a 4-connected region. So it should 90 be able to distinguish basic things like whether or not a location falls on the 91 corner of a white box, on an edge, in the middle, etc. 92 93 94 Note that this object also implements the interface defined in dlib/image_keypoint/hashed_feature_image_abstract.h. 95 This means all the member functions in this object are supposed to behave as 96 described in the hashed_feature_image specification. So when you define your own 97 feature extractor objects you should probably refer yourself to that documentation 98 in addition to reading this example program. 
99 !*/ 100 101 102 public: 103 load(const funny_image & img_)104 inline void load ( 105 const funny_image& img_ 106 ) 107 { 108 const array2d<unsigned char>& img = img_.img; 109 110 feat_image.set_size(img.nr(), img.nc()); 111 assign_all_pixels(feat_image,0); 112 for (long r = 1; r+1 < img.nr(); ++r) 113 { 114 for (long c = 1; c+1 < img.nc(); ++c) 115 { 116 unsigned char f = 0; 117 if (img[r][c]) f |= 0x1; 118 if (img[r][c+1]) f |= 0x2; 119 if (img[r][c-1]) f |= 0x4; 120 if (img[r+1][c]) f |= 0x8; 121 if (img[r-1][c]) f |= 0x10; 122 123 // Store the code value for the pattern of pixel values in the 4-connected 124 // neighborhood around this row and column. 125 feat_image[r][c] = f; 126 } 127 } 128 } 129 load(const array2d<unsigned char> & img)130 inline void load ( 131 const array2d<unsigned char>& img 132 ) 133 { 134 feat_image.set_size(img.nr(), img.nc()); 135 assign_all_pixels(feat_image,0); 136 for (long r = 1; r+1 < img.nr(); ++r) 137 { 138 for (long c = 1; c+1 < img.nc(); ++c) 139 { 140 unsigned char f = 0; 141 if (img[r][c]) f |= 0x1; 142 if (img[r][c+1]) f |= 0x2; 143 if (img[r][c-1]) f |= 0x4; 144 if (img[r+1][c]) f |= 0x8; 145 if (img[r-1][c]) f |= 0x10; 146 147 // Store the code value for the pattern of pixel values in the 4-connected 148 // neighborhood around this row and column. 149 feat_image[r][c] = f; 150 } 151 } 152 } 153 size() const154 inline size_t size () const { return feat_image.size(); } nr() const155 inline long nr () const { return feat_image.nr(); } nc() const156 inline long nc () const { return feat_image.nc(); } 157 get_num_dimensions() const158 inline long get_num_dimensions ( 159 ) const 160 { 161 // Return the dimensionality of the vectors produced by operator() 162 return 32; 163 } 164 165 typedef std::vector<std::pair<unsigned int,double> > descriptor_type; 166 operator ()(long row,long col) const167 inline const descriptor_type& operator() ( 168 long row, 169 long col 170 ) const 171 /*! 
172 requires 173 - 0 <= row < nr() 174 - 0 <= col < nc() 175 ensures 176 - returns a sparse vector which describes the image at the given row and column. 177 In particular, this is a vector that is 0 everywhere except for one element. 178 !*/ 179 { 180 feat.clear(); 181 const unsigned long only_nonzero_element_index = feat_image[row][col]; 182 feat.push_back(make_pair(only_nonzero_element_index,1.0)); 183 return feat; 184 } 185 186 // This block of functions is meant to provide a way to map between the row/col space taken by 187 // this object's operator() function and the images supplied to load(). In this example it's trivial. 188 // However, in general, you might create feature extractors which don't perform extraction at every 189 // possible image location (e.g. the hog_image) and thus result in some more complex mapping. get_block_rect(long row,long col) const190 inline const rectangle get_block_rect ( long row, long col) const { return centered_rect(col,row,3,3); } image_to_feat_space(const point & p) const191 inline const point image_to_feat_space ( const point& p) const { return p; } image_to_feat_space(const rectangle & rect) const192 inline const rectangle image_to_feat_space ( const rectangle& rect) const { return rect; } feat_to_image_space(const point & p) const193 inline const point feat_to_image_space ( const point& p) const { return p; } feat_to_image_space(const rectangle & rect) const194 inline const rectangle feat_to_image_space ( const rectangle& rect) const { return rect; } 195 serialize(const very_simple_feature_extractor & item,std::ostream & out)196 inline friend void serialize ( const very_simple_feature_extractor& item, std::ostream& out) { serialize(item.feat_image, out); } deserialize(very_simple_feature_extractor & item,std::istream & in)197 inline friend void deserialize ( very_simple_feature_extractor& item, std::istream& in ) { deserialize(item.feat_image, in); } 198 copy_configuration(const very_simple_feature_extractor &)199 void 
copy_configuration ( const very_simple_feature_extractor& ){} 200 201 private: 202 array2d<unsigned char> feat_image; 203 204 // This variable doesn't logically contribute to the state of this object. It is here 205 // only to avoid returning a descriptor_type object by value inside the operator() method. 206 mutable descriptor_type feat; 207 }; 208 209 // ---------------------------------------------------------------------------------------- 210 211 template < 212 typename image_array_type 213 > make_simple_test_data(image_array_type & images,std::vector<std::vector<rectangle>> & object_locations)214 void make_simple_test_data ( 215 image_array_type& images, 216 std::vector<std::vector<rectangle> >& object_locations 217 ) 218 { 219 images.clear(); 220 object_locations.clear(); 221 222 images.resize(3); 223 images[0].set_size(400,400); 224 images[1].set_size(400,400); 225 images[2].set_size(400,400); 226 227 // set all the pixel values to black 228 assign_all_pixels(images[0], 0); 229 assign_all_pixels(images[1], 0); 230 assign_all_pixels(images[2], 0); 231 232 // Now make some squares and draw them onto our black images. All the 233 // squares will be 70 pixels wide and tall. 
234 235 std::vector<rectangle> temp; 236 temp.push_back(centered_rect(point(100,100), 70,70)); 237 fill_rect(images[0],temp.back(),255); // Paint the square white 238 temp.push_back(centered_rect(point(200,300), 70,70)); 239 fill_rect(images[0],temp.back(),255); // Paint the square white 240 object_locations.push_back(temp); 241 242 temp.clear(); 243 temp.push_back(centered_rect(point(140,200), 70,70)); 244 fill_rect(images[1],temp.back(),255); // Paint the square white 245 temp.push_back(centered_rect(point(303,200), 70,70)); 246 fill_rect(images[1],temp.back(),255); // Paint the square white 247 object_locations.push_back(temp); 248 249 temp.clear(); 250 temp.push_back(centered_rect(point(123,121), 70,70)); 251 fill_rect(images[2],temp.back(),255); // Paint the square white 252 object_locations.push_back(temp); 253 254 // corrupt each image with random noise just to make this a little more 255 // challenging 256 dlib::rand rnd; 257 for (unsigned long i = 0; i < images.size(); ++i) 258 { 259 for (long r = 0; r < images[i].nr(); ++r) 260 { 261 for (long c = 0; c < images[i].nc(); ++c) 262 { 263 typedef typename image_array_type::type image_type; 264 typedef typename image_type::type type; 265 images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 10*rnd.get_random_gaussian()); 266 } 267 } 268 } 269 } 270 271 template < 272 typename image_array_type 273 > make_simple_test_data(image_array_type & images,std::vector<std::vector<full_object_detection>> & object_locations)274 void make_simple_test_data ( 275 image_array_type& images, 276 std::vector<std::vector<full_object_detection> >& object_locations 277 ) 278 { 279 images.clear(); 280 object_locations.clear(); 281 282 283 images.resize(3); 284 images[0].set_size(400,400); 285 images[1].set_size(400,400); 286 images[2].set_size(400,400); 287 288 // set all the pixel values to black 289 assign_all_pixels(images[0], 0); 290 assign_all_pixels(images[1], 0); 291 assign_all_pixels(images[2], 0); 292 293 // Now make 
some squares and draw them onto our black images. All the 294 // squares will be 70 pixels wide and tall. 295 const int shrink = 0; 296 std::vector<full_object_detection> temp; 297 298 rectangle rect = centered_rect(point(100,100), 70,71); 299 std::vector<point> movable_parts; 300 movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); 301 movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); 302 movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); 303 movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); 304 temp.push_back(full_object_detection(rect, movable_parts)); 305 fill_rect(images[0],rect,255); // Paint the square white 306 307 rect = centered_rect(point(200,200), 70,71); 308 movable_parts.clear(); 309 movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); 310 movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); 311 movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); 312 movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); 313 temp.push_back(full_object_detection(rect, movable_parts)); 314 fill_rect(images[0],rect,255); // Paint the square white 315 316 object_locations.push_back(temp); 317 // ------------------------------------ 318 temp.clear(); 319 320 rect = centered_rect(point(140,200), 70,71); 321 movable_parts.clear(); 322 movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); 323 movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); 324 movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); 325 movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); 326 temp.push_back(full_object_detection(rect, movable_parts)); 327 fill_rect(images[1],rect,255); // Paint the square white 328 329 330 rect = centered_rect(point(303,200), 70,71); 331 movable_parts.clear(); 332 movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); 333 movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); 334 
movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); 335 movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); 336 temp.push_back(full_object_detection(rect, movable_parts)); 337 fill_rect(images[1],rect,255); // Paint the square white 338 339 object_locations.push_back(temp); 340 // ------------------------------------ 341 temp.clear(); 342 343 rect = centered_rect(point(123,121), 70,71); 344 movable_parts.clear(); 345 movable_parts.push_back(shrink_rect(rect,shrink).tl_corner()); 346 movable_parts.push_back(shrink_rect(rect,shrink).tr_corner()); 347 movable_parts.push_back(shrink_rect(rect,shrink).bl_corner()); 348 movable_parts.push_back(shrink_rect(rect,shrink).br_corner()); 349 temp.push_back(full_object_detection(rect, movable_parts)); 350 fill_rect(images[2],rect,255); // Paint the square white 351 352 object_locations.push_back(temp); 353 354 // corrupt each image with random noise just to make this a little more 355 // challenging 356 dlib::rand rnd; 357 for (unsigned long i = 0; i < images.size(); ++i) 358 { 359 for (long r = 0; r < images[i].nr(); ++r) 360 { 361 for (long c = 0; c < images[i].nc(); ++c) 362 { 363 typedef typename image_array_type::type image_type; 364 typedef typename image_type::type type; 365 images[i][r][c] = (type)put_in_range(0,255,images[i][r][c] + 40*rnd.get_random_gaussian()); 366 } 367 } 368 } 369 } 370 371 // ---------------------------------------------------------------------------------------- 372 test_fhog_pyramid()373 void test_fhog_pyramid ( 374 ) 375 { 376 print_spinner(); 377 dlog << LINFO << "test_fhog_pyramid()"; 378 379 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 380 grayscale_image_array_type images; 381 std::vector<std::vector<rectangle> > object_locations; 382 make_simple_test_data(images, object_locations); 383 384 typedef scan_fhog_pyramid<pyramid_down<2> > image_scanner_type; 385 image_scanner_type scanner; 386 scanner.set_detection_window_size(35,35); 387 
structural_object_detection_trainer<image_scanner_type> trainer(scanner); 388 trainer.set_num_threads(4); 389 trainer.set_overlap_tester(test_box_overlap(0,0)); 390 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 391 392 matrix<double> res = test_object_detection_function(detector, images, object_locations); 393 dlog << LINFO << "Test detector (precision,recall): " << res; 394 DLIB_TEST(sum(res) == 3); 395 396 { 397 ostringstream sout; 398 serialize(detector, sout); 399 istringstream sin(sout.str()); 400 object_detector<image_scanner_type> d2; 401 deserialize(d2, sin); 402 matrix<double> res = test_object_detection_function(d2, images, object_locations); 403 dlog << LINFO << "Test detector (precision,recall): " << res; 404 DLIB_TEST(sum(res) == 3); 405 406 validate_some_object_detector_stuff(images, detector, 1e-6); 407 } 408 409 { 410 std::vector<object_detector<image_scanner_type> > detectors; 411 detectors.push_back(detector); 412 detectors.push_back(detector); 413 detectors.push_back(detector); 414 415 std::vector<rectangle> dets1 = evaluate_detectors(detectors, images[0]); 416 std::vector<rectangle> dets2 = detector(images[0]); 417 DLIB_TEST(dets1.size() > 0); 418 DLIB_TEST(dets2.size()*3 == dets1.size()); 419 dlib::set<rectangle>::kernel_1a_c d1, d2; 420 for (unsigned long i = 0; i < dets1.size(); ++i) 421 { 422 if (!d1.is_member(dets1[i])) 423 d1.add(dets1[i]); 424 } 425 for (unsigned long i = 0; i < dets2.size(); ++i) 426 { 427 if (!d2.is_member(dets2[i])) 428 d2.add(dets2[i]); 429 } 430 DLIB_TEST(d1.size() == d2.size()); 431 DLIB_TEST(set_intersection_size(d1,d2) == d1.size()); 432 } 433 } 434 435 // ---------------------------------------------------------------------------------------- 436 test_1()437 void test_1 ( 438 ) 439 { 440 print_spinner(); 441 dlog << LINFO << "test_1()"; 442 443 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 444 grayscale_image_array_type images; 445 
std::vector<std::vector<rectangle> > object_locations; 446 make_simple_test_data(images, object_locations); 447 448 typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type; 449 typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type; 450 image_scanner_type scanner; 451 const rectangle object_box = compute_box_dimensions(1,35*35); 452 scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); 453 setup_hashed_features(scanner, images, 9); 454 use_uniform_feature_weights(scanner); 455 structural_object_detection_trainer<image_scanner_type> trainer(scanner); 456 trainer.set_num_threads(4); 457 trainer.set_overlap_tester(test_box_overlap(0,0)); 458 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 459 460 matrix<double> res = test_object_detection_function(detector, images, object_locations); 461 dlog << LINFO << "Test detector (precision,recall): " << res; 462 DLIB_TEST(sum(res) == 3); 463 464 { 465 ostringstream sout; 466 serialize(detector, sout); 467 istringstream sin(sout.str()); 468 object_detector<image_scanner_type> d2; 469 deserialize(d2, sin); 470 matrix<double> res = test_object_detection_function(d2, images, object_locations); 471 dlog << LINFO << "Test detector (precision,recall): " << res; 472 DLIB_TEST(sum(res) == 3); 473 474 validate_some_object_detector_stuff(images, detector); 475 } 476 } 477 478 // ---------------------------------------------------------------------------------------- 479 test_1_boxes()480 void test_1_boxes ( 481 ) 482 { 483 print_spinner(); 484 dlog << LINFO << "test_1_boxes()"; 485 486 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 487 grayscale_image_array_type images; 488 std::vector<std::vector<rectangle> > object_locations; 489 make_simple_test_data(images, object_locations); 490 491 typedef 
hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type; 492 typedef scan_image_boxes<feature_extractor_type> image_scanner_type; 493 image_scanner_type scanner; 494 setup_hashed_features(scanner, images, 9); 495 use_uniform_feature_weights(scanner); 496 structural_object_detection_trainer<image_scanner_type> trainer(scanner); 497 trainer.set_num_threads(4); 498 trainer.set_overlap_tester(test_box_overlap(0,0)); 499 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 500 501 matrix<double> res = test_object_detection_function(detector, images, object_locations); 502 dlog << LINFO << "Test detector (precision,recall): " << res; 503 DLIB_TEST(sum(res) == 3); 504 505 { 506 ostringstream sout; 507 serialize(detector, sout); 508 istringstream sin(sout.str()); 509 object_detector<image_scanner_type> d2; 510 deserialize(d2, sin); 511 matrix<double> res = test_object_detection_function(d2, images, object_locations); 512 dlog << LINFO << "Test detector (precision,recall): " << res; 513 DLIB_TEST(sum(res) == 3); 514 515 validate_some_object_detector_stuff(images, detector); 516 } 517 } 518 519 // ---------------------------------------------------------------------------------------- 520 test_1m()521 void test_1m ( 522 ) 523 { 524 print_spinner(); 525 dlog << LINFO << "test_1m()"; 526 527 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 528 grayscale_image_array_type images; 529 std::vector<std::vector<full_object_detection> > object_locations; 530 make_simple_test_data(images, object_locations); 531 532 typedef hashed_feature_image<hog_image<3,3,1,4,hog_signed_gradient,hog_full_interpolation> > feature_extractor_type; 533 typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type; 534 image_scanner_type scanner; 535 const rectangle object_box = compute_box_dimensions(1,35*35); 536 std::vector<rectangle> mboxes; 537 const int mbox_size = 
20; 538 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 539 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 540 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 541 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 542 scanner.add_detection_template(object_box, create_grid_detection_template(object_box,1,1), mboxes); 543 setup_hashed_features(scanner, images, 9); 544 use_uniform_feature_weights(scanner); 545 structural_object_detection_trainer<image_scanner_type> trainer(scanner); 546 trainer.set_num_threads(4); 547 trainer.set_overlap_tester(test_box_overlap(0,0)); 548 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 549 550 matrix<double> res = test_object_detection_function(detector, images, object_locations); 551 dlog << LINFO << "Test detector (precision,recall): " << res; 552 DLIB_TEST(sum(res) == 3); 553 554 { 555 ostringstream sout; 556 serialize(detector, sout); 557 istringstream sin(sout.str()); 558 object_detector<image_scanner_type> d2; 559 deserialize(d2, sin); 560 matrix<double> res = test_object_detection_function(d2, images, object_locations); 561 dlog << LINFO << "Test detector (precision,recall): " << res; 562 DLIB_TEST(sum(res) == 3); 563 564 validate_some_object_detector_stuff(images, detector); 565 } 566 } 567 568 // ---------------------------------------------------------------------------------------- 569 test_1_fine_hog()570 void test_1_fine_hog ( 571 ) 572 { 573 print_spinner(); 574 dlog << LINFO << "test_1_fine_hog()"; 575 576 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 577 grayscale_image_array_type images; 578 std::vector<std::vector<rectangle> > object_locations; 579 make_simple_test_data(images, object_locations); 580 581 typedef hashed_feature_image<fine_hog_image<3,3,2,4,hog_signed_gradient> > feature_extractor_type; 582 typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type; 583 
image_scanner_type scanner; 584 const rectangle object_box = compute_box_dimensions(1,35*35); 585 scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); 586 setup_hashed_features(scanner, images, 9); 587 use_uniform_feature_weights(scanner); 588 structural_object_detection_trainer<image_scanner_type> trainer(scanner); 589 trainer.set_num_threads(4); 590 trainer.set_overlap_tester(test_box_overlap(0,0)); 591 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 592 593 matrix<double> res = test_object_detection_function(detector, images, object_locations); 594 dlog << LINFO << "Test detector (precision,recall): " << res; 595 DLIB_TEST(sum(res) == 3); 596 597 { 598 ostringstream sout; 599 serialize(detector, sout); 600 istringstream sin(sout.str()); 601 object_detector<image_scanner_type> d2; 602 deserialize(d2, sin); 603 matrix<double> res = test_object_detection_function(d2, images, object_locations); 604 dlog << LINFO << "Test detector (precision,recall): " << res; 605 DLIB_TEST(sum(res) == 3); 606 607 validate_some_object_detector_stuff(images, detector); 608 } 609 } 610 611 // ---------------------------------------------------------------------------------------- 612 test_1_poly()613 void test_1_poly ( 614 ) 615 { 616 print_spinner(); 617 dlog << LINFO << "test_1_poly()"; 618 619 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 620 grayscale_image_array_type images; 621 std::vector<std::vector<rectangle> > object_locations; 622 make_simple_test_data(images, object_locations); 623 624 typedef hashed_feature_image<poly_image<2> > feature_extractor_type; 625 typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type; 626 image_scanner_type scanner; 627 const rectangle object_box = compute_box_dimensions(1,35*35); 628 scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2)); 629 setup_hashed_features(scanner, 
images, 9); 630 use_uniform_feature_weights(scanner); 631 structural_object_detection_trainer<image_scanner_type> trainer(scanner); 632 trainer.set_num_threads(4); 633 trainer.set_overlap_tester(test_box_overlap(0,0)); 634 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 635 636 matrix<double> res = test_object_detection_function(detector, images, object_locations); 637 dlog << LINFO << "Test detector (precision,recall): " << res; 638 DLIB_TEST(sum(res) == 3); 639 640 { 641 ostringstream sout; 642 serialize(detector, sout); 643 istringstream sin(sout.str()); 644 object_detector<image_scanner_type> d2; 645 deserialize(d2, sin); 646 matrix<double> res = test_object_detection_function(d2, images, object_locations); 647 dlog << LINFO << "Test detector (precision,recall): " << res; 648 DLIB_TEST(sum(res) == 3); 649 650 validate_some_object_detector_stuff(images, detector); 651 } 652 } 653 654 // ---------------------------------------------------------------------------------------- 655 test_1m_poly()656 void test_1m_poly ( 657 ) 658 { 659 print_spinner(); 660 dlog << LINFO << "test_1_poly()"; 661 662 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 663 grayscale_image_array_type images; 664 std::vector<std::vector<full_object_detection> > object_locations; 665 make_simple_test_data(images, object_locations); 666 667 typedef hashed_feature_image<poly_image<2> > feature_extractor_type; 668 typedef scan_image_pyramid<pyramid_down<3>, feature_extractor_type> image_scanner_type; 669 image_scanner_type scanner; 670 const rectangle object_box = compute_box_dimensions(1,35*35); 671 std::vector<rectangle> mboxes; 672 const int mbox_size = 20; 673 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 674 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 675 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 676 mboxes.push_back(centered_rect(0,0, mbox_size,mbox_size)); 677 
scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2), mboxes); 678 setup_hashed_features(scanner, images, 9); 679 use_uniform_feature_weights(scanner); 680 structural_object_detection_trainer<image_scanner_type> trainer(scanner); 681 trainer.set_num_threads(4); 682 trainer.set_overlap_tester(test_box_overlap(0,0)); 683 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 684 685 matrix<double> res = test_object_detection_function(detector, images, object_locations); 686 dlog << LINFO << "Test detector (precision,recall): " << res; 687 DLIB_TEST(sum(res) == 3); 688 689 { 690 ostringstream sout; 691 serialize(detector, sout); 692 istringstream sin(sout.str()); 693 object_detector<image_scanner_type> d2; 694 deserialize(d2, sin); 695 matrix<double> res = test_object_detection_function(d2, images, object_locations); 696 dlog << LINFO << "Test detector (precision,recall): " << res; 697 DLIB_TEST(sum(res) == 3); 698 699 validate_some_object_detector_stuff(images, detector); 700 } 701 } 702 703 // ---------------------------------------------------------------------------------------- 704 test_1_poly_nn()705 void test_1_poly_nn ( 706 ) 707 { 708 print_spinner(); 709 dlog << LINFO << "test_1_poly_nn()"; 710 711 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 712 grayscale_image_array_type images; 713 std::vector<std::vector<rectangle> > object_locations; 714 make_simple_test_data(images, object_locations); 715 716 typedef nearest_neighbor_feature_image<poly_image<5> > feature_extractor_type; 717 typedef scan_image_pyramid<pyramid_down<2>, feature_extractor_type> image_scanner_type; 718 image_scanner_type scanner; 719 720 setup_grid_detection_templates(scanner, object_locations, 2, 2); 721 feature_extractor_type nnfe; 722 pyramid_down<2> pyr_down; 723 poly_image<5> polyi; 724 nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80)); 725 
// (continuation of the preceding test function: train, evaluate, serialize round trip)
    scanner.copy_configuration(nnfe);

    structural_object_detection_trainer<image_scanner_type> trainer(scanner);
    trainer.set_num_threads(4);
    object_detector<image_scanner_type> detector = trainer.train(images, object_locations);

    matrix<double> res = test_object_detection_function(detector, images, object_locations);
    dlog << LINFO << "Test detector (precision,recall): " << res;
    DLIB_TEST(sum(res) == 3);

    {
        ostringstream sout;
        serialize(detector, sout);
        istringstream sin(sout.str());
        object_detector<image_scanner_type> d2;
        deserialize(d2, sin);
        matrix<double> res = test_object_detection_function(d2, images, object_locations);
        dlog << LINFO << "Test detector (precision,recall): " << res;
        DLIB_TEST(sum(res) == 3);

        validate_some_object_detector_stuff(images, detector);
    }
}

// ----------------------------------------------------------------------------------------

// Trains a scan_image_boxes detector whose features come from a
// nearest_neighbor_feature_image<poly_image<5> > extractor.  The extractor's basis is
// produced by randomly_sample_image_features() given a pyramid_down<2>, a poly_image<5>
// and the count 80 (presumably the number of sampled basis features -- confirm).
// The trained detector must score sum(res) == 3 on its own training images both
// directly and after a serialize()/deserialize() round trip.
// NOTE(review): the log text says "(precision,recall)" but the check is sum == 3,
// which implies three entries that are all 1 (presumably precision, recall and
// average precision) -- confirm against test_object_detection_function().
void test_1_poly_nn_boxes (
)
{
    print_spinner();
    dlog << LINFO << "test_1_poly_nn_boxes()";

    typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
    grayscale_image_array_type images;
    std::vector<std::vector<rectangle> > object_locations;
    make_simple_test_data(images, object_locations);

    typedef nearest_neighbor_feature_image<poly_image<5> > feature_extractor_type;
    typedef scan_image_boxes<feature_extractor_type> image_scanner_type;
    image_scanner_type scanner;

    // Build the nearest-neighbor basis from features sampled out of the training images.
    feature_extractor_type nnfe;
    pyramid_down<2> pyr_down;
    poly_image<5> polyi;
    nnfe.set_basis(randomly_sample_image_features(images, pyr_down, polyi, 80));
    scanner.copy_configuration(nnfe);

    structural_object_detection_trainer<image_scanner_type> trainer(scanner);
    trainer.set_num_threads(4);
    object_detector<image_scanner_type> detector = trainer.train(images, object_locations);

    // Evaluate on the same images the detector was trained on.
    matrix<double> res = test_object_detection_function(detector, images, object_locations);
    dlog << LINFO << "Test detector (precision,recall): " << res;
    DLIB_TEST(sum(res) == 3);

    // The deserialized copy must reach the same score.
    {
        ostringstream sout;
        serialize(detector, sout);
        istringstream sin(sout.str());
        object_detector<image_scanner_type> d2;
        deserialize(d2, sin);
        matrix<double> res = test_object_detection_function(d2, images, object_locations);
        dlog << LINFO << "Test detector (precision,recall): " << res;
        DLIB_TEST(sum(res) == 3);

        // NOTE(review): this validates the original `detector`, not `d2`; confirm intended.
        validate_some_object_detector_stuff(images, detector);
    }
}

// ----------------------------------------------------------------------------------------

// Trains a scan_image_pyramid<pyramid_down<5>, very_simple_feature_extractor> detector
// with a single detection template: a box from compute_box_dimensions(1,70*70)
// (presumably aspect ratio 1, area 70*70 -- confirm), subdivided by
// create_grid_detection_template(object_box,2,2).  Scanning is limited to one pyramid
// level and the trainer is given set_num_threads(0) (presumably to exercise the
// non-threaded path -- confirm).  Checks the training-set score, a 3-fold
// cross-validation score, and the score after a serialize()/deserialize() round trip.
void test_2 (
)
{
    print_spinner();
    dlog << LINFO << "test_2()";

    typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
    grayscale_image_array_type images;
    std::vector<std::vector<rectangle> > object_locations;
    make_simple_test_data(images, object_locations);

    typedef scan_image_pyramid<pyramid_down<5>, very_simple_feature_extractor> image_scanner_type;
    image_scanner_type scanner;
    const rectangle object_box = compute_box_dimensions(1,70*70);
    scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
    scanner.set_max_pyramid_levels(1);
    structural_object_detection_trainer<image_scanner_type> trainer(scanner);
    trainer.set_num_threads(0);
    object_detector<image_scanner_type> detector = trainer.train(images, object_locations);

    matrix<double> res = test_object_detection_function(detector, images, object_locations);
    dlog << LINFO << "Test detector (precision,recall): " << res;
    DLIB_TEST(sum(res) == 3);

    res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
    dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
    DLIB_TEST(sum(res) == 3);

    {
        ostringstream sout;
        serialize(detector, sout);
        istringstream sin(sout.str());
        object_detector<image_scanner_type> d2;
        deserialize(d2, sin);
        matrix<double> res = test_object_detection_function(d2, images, object_locations);
        dlog << LINFO << "Test detector (precision,recall): " << res;
        DLIB_TEST(sum(res) == 3);
        // NOTE(review): validates the original `detector`, not `d2`; confirm intended.
        validate_some_object_detector_stuff(images, detector);
    }
}

// ----------------------------------------------------------------------------------------

// A pyramid type that wraps a pyramid_down<2>.  Every point/rectangle mapping is
// forwarded unchanged to the wrapped object.  The one difference is operator(), which
// downsamples the `.img` member of its arguments, so it can operate on funny_image
// objects instead of bare array2d images.
class pyramid_down_funny : noncopyable
{
    pyramid_down<2> pyr;
public:

    template <typename T>
    dlib::vector<double,2> point_down ( const dlib::vector<T,2>& p) const { return pyr.point_down(p); }

    template <typename T>
    dlib::vector<double,2> point_up ( const dlib::vector<T,2>& p) const { return pyr.point_up(p); }

    template <typename T>
    dlib::vector<double,2> point_down ( const dlib::vector<T,2>& p, unsigned int levels) const { return pyr.point_down(p,levels); }

    template <typename T>
    dlib::vector<double,2> point_up ( const dlib::vector<T,2>& p, unsigned int levels) const { return pyr.point_up(p,levels); }

    rectangle rect_up ( const rectangle& rect) const { return pyr.rect_up(rect); }

    rectangle rect_up ( const rectangle& rect, unsigned int levels) const { return pyr.rect_up(rect,levels); }

    rectangle rect_down ( const rectangle& rect) const { return pyr.rect_down(rect); }

    rectangle rect_down ( const rectangle& rect, unsigned int levels) const { return pyr.rect_down(rect,levels); }

    // Downsample original.img into down.img; both arguments must expose an `img` member.
    template <
        typename in_image_type,
        typename out_image_type
        >
    void operator() (
        const in_image_type& original,
        out_image_type& down
    ) const
    {
        pyr(original.img, down.img);
    }

};

// make sure everything works even when the image isn't a dlib::array2d.
// So test with funny_image.
// The pixel data from make_simple_test_data() is swapped into funny_image wrappers,
// and pyramid_down_funny is used as the scanner's pyramid type.
void test_3 (
)
{
    print_spinner();
    dlog << LINFO << "test_3()";


    typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type;
    typedef dlib::array<funny_image> funny_image_array_type;
    grayscale_image_array_type images_temp;
    funny_image_array_type images;
    std::vector<std::vector<rectangle> > object_locations;
    make_simple_test_data(images_temp, object_locations);
    images.resize(images_temp.size());
    // Move (swap) each generated image into its funny_image wrapper.
    for (unsigned long i = 0; i < images_temp.size(); ++i)
    {
        images[i].img.swap(images_temp[i]);
    }

    typedef scan_image_pyramid<pyramid_down_funny, very_simple_feature_extractor> image_scanner_type;
    image_scanner_type scanner;
    const rectangle object_box = compute_box_dimensions(1,70*70);
    scanner.add_detection_template(object_box, create_grid_detection_template(object_box,2,2));
    scanner.set_max_pyramid_levels(1);
    structural_object_detection_trainer<image_scanner_type> trainer(scanner);
    trainer.set_num_threads(4);
    object_detector<image_scanner_type> detector = trainer.train(images, object_locations);

    matrix<double> res = test_object_detection_function(detector, images, object_locations);
    dlog << LINFO << "Test detector (precision,recall): " << res;
    DLIB_TEST(sum(res) == 3);

    res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3);
    dlog << LINFO << "3-fold cross validation (precision,recall): " << res;
    DLIB_TEST(sum(res) == 3);

    {
        ostringstream sout;
        serialize(detector, sout);
        istringstream sin(sout.str());
        object_detector<image_scanner_type> d2;
        deserialize(d2, sin);
        matrix<double> res = test_object_detection_function(d2, images, object_locations);
        dlog << LINFO << "Test detector (precision,recall): " << res;
        DLIB_TEST(sum(res) == 3);
    }
}

// ----------------------------------------------------------------------------------------

// Candidate-box generator that runs find_candidate_object_locations() on the `.img`
// member of its input, so it accepts funny_image.  It clears `rects` before filling
// it and logs how many candidate rectangles were produced.
class funny_box_generator
{
public:
    template <typename image_type>
    void operator() (
        const image_type& img,
        std::vector<rectangle>& rects
    ) const
    {
        rects.clear();
        find_candidate_object_locations(img.img, rects);
        dlog << LINFO << "funny_box_generator, rects.size(): "<< rects.size();
    }
};

// funny_box_generator has no data members, so there is nothing to save or restore.
// These hooks presumably exist because serializing the detector serializes its box
// generator too -- confirm in scan_image_boxes.
inline void serialize(const funny_box_generator&, std::ostream& ) {}
inline void deserialize(funny_box_generator&, std::istream& ) {}


// make sure everything works even when the image isn't a dlib::array2d.
// So test with funny_image.
test_3_boxes()951 void test_3_boxes ( 952 ) 953 { 954 print_spinner(); 955 dlog << LINFO << "test_3_boxes()"; 956 957 958 typedef dlib::array<array2d<unsigned char> > grayscale_image_array_type; 959 typedef dlib::array<funny_image> funny_image_array_type; 960 grayscale_image_array_type images_temp; 961 funny_image_array_type images; 962 std::vector<std::vector<rectangle> > object_locations; 963 make_simple_test_data(images_temp, object_locations); 964 images.resize(images_temp.size()); 965 for (unsigned long i = 0; i < images_temp.size(); ++i) 966 { 967 images[i].img.swap(images_temp[i]); 968 } 969 970 typedef scan_image_boxes<very_simple_feature_extractor, funny_box_generator> image_scanner_type; 971 image_scanner_type scanner; 972 structural_object_detection_trainer<image_scanner_type> trainer(scanner); 973 trainer.set_num_threads(4); 974 object_detector<image_scanner_type> detector = trainer.train(images, object_locations); 975 976 matrix<double> res = test_object_detection_function(detector, images, object_locations); 977 dlog << LINFO << "Test detector (precision,recall): " << res; 978 DLIB_TEST(sum(res) == 3); 979 980 res = cross_validate_object_detection_trainer(trainer, images, object_locations, 3); 981 dlog << LINFO << "3-fold cross validation (precision,recall): " << res; 982 DLIB_TEST(sum(res) == 3); 983 984 { 985 ostringstream sout; 986 serialize(detector, sout); 987 istringstream sin(sout.str()); 988 object_detector<image_scanner_type> d2; 989 deserialize(d2, sin); 990 matrix<double> res = test_object_detection_function(d2, images, object_locations); 991 dlog << LINFO << "Test detector (precision,recall): " << res; 992 DLIB_TEST(sum(res) == 3); 993 } 994 } 995 996 // ---------------------------------------------------------------------------------------- 997 998 class object_detector_tester : public tester 999 { 1000 public: object_detector_tester()1001 object_detector_tester ( 1002 ) : 1003 tester ("test_object_detector", 1004 "Runs tests on the 
structural object detection stuff.") 1005 {} 1006 perform_test()1007 void perform_test ( 1008 ) 1009 { 1010 test_fhog_pyramid(); 1011 test_1_boxes(); 1012 test_1_poly_nn_boxes(); 1013 test_3_boxes(); 1014 1015 test_1(); 1016 test_1m(); 1017 test_1_fine_hog(); 1018 test_1_poly(); 1019 test_1m_poly(); 1020 test_1_poly_nn(); 1021 test_2(); 1022 test_3(); 1023 } 1024 } a; 1025 1026 } 1027 1028 1029