1 // Copyright (C) 2011 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_OBJECT_DeTECTOR_Hh_ 4 #define DLIB_OBJECT_DeTECTOR_Hh_ 5 6 #include "object_detector_abstract.h" 7 #include "../geometry.h" 8 #include <vector> 9 #include "box_overlap_testing.h" 10 #include "full_object_detection.h" 11 12 namespace dlib 13 { 14 15 // ---------------------------------------------------------------------------------------- 16 17 struct rect_detection 18 { 19 double detection_confidence; 20 unsigned long weight_index; 21 rectangle rect; 22 23 bool operator<(const rect_detection& item) const { return detection_confidence < item.detection_confidence; } 24 }; 25 26 struct full_detection 27 { 28 double detection_confidence; 29 unsigned long weight_index; 30 full_object_detection rect; 31 32 bool operator<(const full_detection& item) const { return detection_confidence < item.detection_confidence; } 33 }; 34 35 // ---------------------------------------------------------------------------------------- 36 37 template <typename image_scanner_type> 38 struct processed_weight_vector 39 { processed_weight_vectorprocessed_weight_vector40 processed_weight_vector(){} 41 42 typedef typename image_scanner_type::feature_vector_type feature_vector_type; 43 initprocessed_weight_vector44 void init ( 45 const image_scanner_type& 46 ) 47 /*! 48 requires 49 - w has already been assigned its value. Note that the point of this 50 function is to allow an image scanner to overload the 51 processed_weight_vector template and provide some different kind of 52 object as the output of get_detect_argument(). For example, the 53 scan_fhog_pyramid object uses an overload that causes 54 get_detect_argument() to return the special fhog_filterbank object 55 instead of a feature_vector_type. This avoids needing to construct the 56 fhog_filterbank during each call to detect and therefore speeds up 57 detection. 58 !*/ 59 {} 60 61 // return the first argument to image_scanner_type::detect() get_detect_argumentprocessed_weight_vector62 const feature_vector_type& get_detect_argument() const { return w; } 63 64 feature_vector_type w; 65 }; 66 67 // ---------------------------------------------------------------------------------------- 68 69 template < 70 typename image_scanner_type_ 71 > 72 class object_detector 73 { 74 public: 75 typedef image_scanner_type_ image_scanner_type; 76 typedef typename image_scanner_type::feature_vector_type feature_vector_type; 77 78 object_detector ( 79 ); 80 81 object_detector ( 82 const object_detector& item 83 ); 84 85 object_detector ( 86 const image_scanner_type& scanner_, 87 const test_box_overlap& overlap_tester_, 88 const feature_vector_type& w_ 89 ); 90 91 object_detector ( 92 const image_scanner_type& scanner_, 93 const test_box_overlap& overlap_tester_, 94 const std::vector<feature_vector_type>& w_ 95 ); 96 97 explicit object_detector ( 98 const std::vector<object_detector>& detectors 99 ); 100 num_detectors()101 unsigned long num_detectors ( 102 ) const { return w.size(); } 103 104 const feature_vector_type& get_w ( 105 unsigned long idx = 0 106 ) const { return w[idx].w; } 107 108 const processed_weight_vector<image_scanner_type>& get_processed_w ( 109 unsigned long idx = 0 110 ) const { return w[idx]; } 111 112 const test_box_overlap& get_overlap_tester ( 113 ) const; 114 115 const image_scanner_type& get_scanner ( 116 ) const; 117 118 object_detector& operator= ( 119 const object_detector& item 120 ); 121 122 template < 123 typename image_type 124 > 125 std::vector<rectangle> operator() ( 126 const image_type& img, 127 double adjust_threshold = 0 128 ); 129 130 template < 131 typename image_type 132 > 133 void operator() ( 134 const image_type& img, 135 std::vector<std::pair<double, rectangle> >& final_dets, 136 double adjust_threshold = 0 137 ); 138 139 template < 140 typename image_type 141 > 142 void operator() ( 143 const image_type& img, 144 std::vector<std::pair<double, full_object_detection> >& final_dets, 145 double adjust_threshold = 0 146 ); 147 148 template < 149 typename image_type 150 > 151 void operator() ( 152 const image_type& img, 153 std::vector<full_object_detection>& final_dets, 154 double adjust_threshold = 0 155 ); 156 157 // These typedefs are here for backwards compatibility with previous versions of 158 // dlib. 159 typedef ::dlib::rect_detection rect_detection; 160 typedef ::dlib::full_detection full_detection; 161 162 template < 163 typename image_type 164 > 165 void operator() ( 166 const image_type& img, 167 std::vector<rect_detection>& final_dets, 168 double adjust_threshold = 0 169 ); 170 171 template < 172 typename image_type 173 > 174 void operator() ( 175 const image_type& img, 176 std::vector<full_detection>& final_dets, 177 double adjust_threshold = 0 178 ); 179 180 template <typename T> 181 friend void serialize ( 182 const object_detector<T>& item, 183 std::ostream& out 184 ); 185 186 template <typename T> 187 friend void deserialize ( 188 object_detector<T>& item, 189 std::istream& in 190 ); 191 192 private: 193 overlaps_any_box(const std::vector<rect_detection> & rects,const dlib::rectangle & rect)194 bool overlaps_any_box ( 195 const std::vector<rect_detection>& rects, 196 const dlib::rectangle& rect 197 ) const 198 { 199 for (unsigned long i = 0; i < rects.size(); ++i) 200 { 201 if (boxes_overlap(rects[i].rect, rect)) 202 return true; 203 } 204 return false; 205 } 206 207 test_box_overlap boxes_overlap; 208 std::vector<processed_weight_vector<image_scanner_type> > w; 209 image_scanner_type scanner; 210 }; 211 212 // ---------------------------------------------------------------------------------------- 213 214 template <typename T> serialize(const object_detector<T> & item,std::ostream & out)215 void serialize ( 216 const object_detector<T>& item, 217 std::ostream& out 218 ) 219 { 220 int version = 2; 221 serialize(version, out); 222 223 T scanner; 224 scanner.copy_configuration(item.scanner); 225 serialize(scanner, out); 226 serialize(item.boxes_overlap, out); 227 // serialize all the weight vectors 228 serialize(item.w.size(), out); 229 for (unsigned long i = 0; i < item.w.size(); ++i) 230 serialize(item.w[i].w, out); 231 } 232 233 // ---------------------------------------------------------------------------------------- 234 235 template <typename T> deserialize(object_detector<T> & item,std::istream & in)236 void deserialize ( 237 object_detector<T>& item, 238 std::istream& in 239 ) 240 { 241 int version = 0; 242 deserialize(version, in); 243 if (version == 1) 244 { 245 deserialize(item.scanner, in); 246 item.w.resize(1); 247 deserialize(item.w[0].w, in); 248 item.w[0].init(item.scanner); 249 deserialize(item.boxes_overlap, in); 250 } 251 else if (version == 2) 252 { 253 deserialize(item.scanner, in); 254 deserialize(item.boxes_overlap, in); 255 unsigned long num_detectors = 0; 256 deserialize(num_detectors, in); 257 item.w.resize(num_detectors); 258 for (unsigned long i = 0; i < item.w.size(); ++i) 259 { 260 deserialize(item.w[i].w, in); 261 item.w[i].init(item.scanner); 262 } 263 } 264 else 265 { 266 throw serialization_error("Unexpected version encountered while deserializing a dlib::object_detector object."); 267 } 268 } 269 270 // ---------------------------------------------------------------------------------------- 271 // ---------------------------------------------------------------------------------------- 272 // object_detector member functions 273 // ---------------------------------------------------------------------------------------- 274 // ---------------------------------------------------------------------------------------- 275 276 template < 277 typename image_scanner_type 278 > 279 object_detector<image_scanner_type>:: object_detector()280 object_detector ( 281 ) 282 { 283 } 284 285 // ---------------------------------------------------------------------------------------- 286 287 template < 288 typename image_scanner_type 289 > 290 object_detector<image_scanner_type>:: object_detector(const object_detector & item)291 object_detector ( 292 const object_detector& item 293 ) 294 { 295 boxes_overlap = item.boxes_overlap; 296 w = item.w; 297 scanner.copy_configuration(item.scanner); 298 } 299 300 // ---------------------------------------------------------------------------------------- 301 302 template < 303 typename image_scanner_type 304 > 305 object_detector<image_scanner_type>:: object_detector(const image_scanner_type & scanner_,const test_box_overlap & overlap_tester,const feature_vector_type & w_)306 object_detector ( 307 const image_scanner_type& scanner_, 308 const test_box_overlap& overlap_tester, 309 const feature_vector_type& w_ 310 ) : 311 boxes_overlap(overlap_tester) 312 { 313 // make sure requires clause is not broken 314 DLIB_ASSERT(scanner_.get_num_detection_templates() > 0 && 315 w_.size() == scanner_.get_num_dimensions() + 1, 316 "\t object_detector::object_detector(scanner_,overlap_tester,w_)" 317 << "\n\t Invalid inputs were given to this function " 318 << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() 319 << "\n\t w_.size(): " << w_.size() 320 << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions() 321 << "\n\t this: " << this 322 ); 323 324 scanner.copy_configuration(scanner_); 325 w.resize(1); 326 w[0].w = w_; 327 w[0].init(scanner); 328 } 329 330 // ---------------------------------------------------------------------------------------- 331 332 template < 333 typename image_scanner_type 334 > 335 object_detector<image_scanner_type>:: object_detector(const image_scanner_type & scanner_,const test_box_overlap & overlap_tester,const std::vector<feature_vector_type> & w_)336 object_detector ( 337 const image_scanner_type& scanner_, 338 const test_box_overlap& overlap_tester, 339 const std::vector<feature_vector_type>& w_ 340 ) : 341 boxes_overlap(overlap_tester) 342 { 343 // make sure requires clause is not broken 344 DLIB_CASSERT(scanner_.get_num_detection_templates() > 0 && w_.size() > 0, 345 "\t object_detector::object_detector(scanner_,overlap_tester,w_)" 346 << "\n\t Invalid inputs were given to this function " 347 << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() 348 << "\n\t w_.size(): " << w_.size() 349 << "\n\t this: " << this 350 ); 351 352 for (unsigned long i = 0; i < w_.size(); ++i) 353 { 354 DLIB_CASSERT(w_[i].size() == scanner_.get_num_dimensions() + 1, 355 "\t object_detector::object_detector(scanner_,overlap_tester,w_)" 356 << "\n\t Invalid inputs were given to this function " 357 << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates() 358 << "\n\t w_["<<i<<"].size(): " << w_[i].size() 359 << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions() 360 << "\n\t this: " << this 361 ); 362 } 363 364 scanner.copy_configuration(scanner_); 365 w.resize(w_.size()); 366 for (unsigned long i = 0; i < w.size(); ++i) 367 { 368 w[i].w = w_[i]; 369 w[i].init(scanner); 370 } 371 } 372 373 // ---------------------------------------------------------------------------------------- 374 375 template < 376 typename image_scanner_type 377 > 378 object_detector<image_scanner_type>:: object_detector(const std::vector<object_detector> & detectors)379 object_detector ( 380 const std::vector<object_detector>& detectors 381 ) 382 { 383 DLIB_CASSERT(detectors.size() != 0, 384 "\t object_detector::object_detector(detectors)" 385 << "\n\t Invalid inputs were given to this function " 386 << "\n\t this: " << this 387 ); 388 std::vector<feature_vector_type> weights; 389 weights.reserve(detectors.size()); 390 for (unsigned long i = 0; i < detectors.size(); ++i) 391 { 392 for (unsigned long j = 0; j < detectors[i].num_detectors(); ++j) 393 weights.push_back(detectors[i].get_w(j)); 394 } 395 396 *this = object_detector(detectors[0].get_scanner(), detectors[0].get_overlap_tester(), weights); 397 } 398 399 // ---------------------------------------------------------------------------------------- 400 401 template < 402 typename image_scanner_type 403 > 404 object_detector<image_scanner_type>& object_detector<image_scanner_type>:: 405 operator= ( 406 const object_detector& item 407 ) 408 { 409 if (this == &item) 410 return *this; 411 412 boxes_overlap = item.boxes_overlap; 413 w = item.w; 414 scanner.copy_configuration(item.scanner); 415 return *this; 416 } 417 418 // ---------------------------------------------------------------------------------------- 419 420 template < 421 typename image_scanner_type 422 > 423 template < 424 typename image_type 425 > 426 void object_detector<image_scanner_type>:: operator()427 operator() ( 428 const image_type& img, 429 std::vector<rect_detection>& final_dets, 430 double adjust_threshold 431 ) 432 { 433 scanner.load(img); 434 std::vector<std::pair<double, rectangle> > dets; 435 std::vector<rect_detection> dets_accum; 436 for (unsigned long i = 0; i < w.size(); ++i) 437 { 438 const double thresh = w[i].w(scanner.get_num_dimensions()); 439 scanner.detect(w[i].get_detect_argument(), dets, thresh + adjust_threshold); 440 for (unsigned long j = 0; j < dets.size(); ++j) 441 { 442 rect_detection temp; 443 temp.detection_confidence = dets[j].first-thresh; 444 temp.weight_index = i; 445 temp.rect = dets[j].second; 446 dets_accum.push_back(temp); 447 } 448 } 449 450 // Do non-max suppression 451 final_dets.clear(); 452 if (w.size() > 1) 453 std::sort(dets_accum.rbegin(), dets_accum.rend()); 454 for (unsigned long i = 0; i < dets_accum.size(); ++i) 455 { 456 if (overlaps_any_box(final_dets, dets_accum[i].rect)) 457 continue; 458 459 final_dets.push_back(dets_accum[i]); 460 } 461 } 462 463 // ---------------------------------------------------------------------------------------- 464 465 template < 466 typename image_scanner_type 467 > 468 template < 469 typename image_type 470 > 471 void object_detector<image_scanner_type>:: operator()472 operator() ( 473 const image_type& img, 474 std::vector<full_detection>& final_dets, 475 double adjust_threshold 476 ) 477 { 478 std::vector<rect_detection> dets; 479 (*this)(img,dets,adjust_threshold); 480 481 final_dets.resize(dets.size()); 482 483 // convert all the rectangle detections into full_object_detections. 484 for (unsigned long i = 0; i < dets.size(); ++i) 485 { 486 final_dets[i].detection_confidence = dets[i].detection_confidence; 487 final_dets[i].weight_index = dets[i].weight_index; 488 final_dets[i].rect = scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w); 489 } 490 } 491 492 // ---------------------------------------------------------------------------------------- 493 494 template < 495 typename image_scanner_type 496 > 497 template < 498 typename image_type 499 > 500 std::vector<rectangle> object_detector<image_scanner_type>:: operator()501 operator() ( 502 const image_type& img, 503 double adjust_threshold 504 ) 505 { 506 std::vector<rect_detection> dets; 507 (*this)(img,dets,adjust_threshold); 508 509 std::vector<rectangle> final_dets(dets.size()); 510 for (unsigned long i = 0; i < dets.size(); ++i) 511 final_dets[i] = dets[i].rect; 512 513 return final_dets; 514 } 515 516 // ---------------------------------------------------------------------------------------- 517 518 template < 519 typename image_scanner_type 520 > 521 template < 522 typename image_type 523 > 524 void object_detector<image_scanner_type>:: operator()525 operator() ( 526 const image_type& img, 527 std::vector<std::pair<double, rectangle> >& final_dets, 528 double adjust_threshold 529 ) 530 { 531 std::vector<rect_detection> dets; 532 (*this)(img,dets,adjust_threshold); 533 534 final_dets.resize(dets.size()); 535 for (unsigned long i = 0; i < dets.size(); ++i) 536 final_dets[i] = std::make_pair(dets[i].detection_confidence,dets[i].rect); 537 } 538 539 // ---------------------------------------------------------------------------------------- 540 541 template < 542 typename image_scanner_type 543 > 544 template < 545 typename image_type 546 > 547 void object_detector<image_scanner_type>:: operator()548 operator() ( 549 const image_type& img, 550 std::vector<std::pair<double, full_object_detection> >& final_dets, 551 double adjust_threshold 552 ) 553 { 554 std::vector<rect_detection> dets; 555 (*this)(img,dets,adjust_threshold); 556 557 final_dets.clear(); 558 final_dets.reserve(dets.size()); 559 560 // convert all the rectangle detections into full_object_detections. 561 for (unsigned long i = 0; i < dets.size(); ++i) 562 { 563 final_dets.push_back(std::make_pair(dets[i].detection_confidence, 564 scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w))); 565 } 566 } 567 568 // ---------------------------------------------------------------------------------------- 569 570 template < 571 typename image_scanner_type 572 > 573 template < 574 typename image_type 575 > 576 void object_detector<image_scanner_type>:: operator()577 operator() ( 578 const image_type& img, 579 std::vector<full_object_detection>& final_dets, 580 double adjust_threshold 581 ) 582 { 583 std::vector<rect_detection> dets; 584 (*this)(img,dets,adjust_threshold); 585 586 final_dets.clear(); 587 final_dets.reserve(dets.size()); 588 589 // convert all the rectangle detections into full_object_detections. 590 for (unsigned long i = 0; i < dets.size(); ++i) 591 { 592 final_dets.push_back(scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w)); 593 } 594 } 595 596 // ---------------------------------------------------------------------------------------- 597 598 template < 599 typename image_scanner_type 600 > 601 const test_box_overlap& object_detector<image_scanner_type>:: get_overlap_tester()602 get_overlap_tester ( 603 ) const 604 { 605 return boxes_overlap; 606 } 607 608 // ---------------------------------------------------------------------------------------- 609 610 template < 611 typename image_scanner_type 612 > 613 const image_scanner_type& object_detector<image_scanner_type>:: get_scanner()614 get_scanner ( 615 ) const 616 { 617 return scanner; 618 } 619 620 // ---------------------------------------------------------------------------------------- 621 622 } 623 624 #endif // DLIB_OBJECT_DeTECTOR_Hh_ 625 626 627