1 // Copyright (C) 2011  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_OBJECT_DeTECTOR_Hh_
4 #define DLIB_OBJECT_DeTECTOR_Hh_
5 
6 #include "object_detector_abstract.h"
7 #include "../geometry.h"
8 #include <vector>
9 #include "box_overlap_testing.h"
10 #include "full_object_detection.h"
11 
12 namespace dlib
13 {
14 
15 // ----------------------------------------------------------------------------------------
16 
17     struct rect_detection
18     {
19         double detection_confidence;
20         unsigned long weight_index;
21         rectangle rect;
22 
23         bool operator<(const rect_detection& item) const { return detection_confidence < item.detection_confidence; }
24     };
25 
26     struct full_detection
27     {
28         double detection_confidence;
29         unsigned long weight_index;
30         full_object_detection rect;
31 
32         bool operator<(const full_detection& item) const { return detection_confidence < item.detection_confidence; }
33     };
34 
35 // ----------------------------------------------------------------------------------------
36 
37     template <typename image_scanner_type>
38     struct processed_weight_vector
39     {
processed_weight_vectorprocessed_weight_vector40         processed_weight_vector(){}
41 
42         typedef typename image_scanner_type::feature_vector_type feature_vector_type;
43 
initprocessed_weight_vector44         void init (
45             const image_scanner_type&
46         )
47         /*!
48             requires
49                 - w has already been assigned its value.  Note that the point of this
50                   function is to allow an image scanner to overload the
51                   processed_weight_vector template and provide some different kind of
52                   object as the output of get_detect_argument().  For example, the
53                   scan_fhog_pyramid object uses an overload that causes
54                   get_detect_argument() to return the special fhog_filterbank object
55                   instead of a feature_vector_type.  This avoids needing to construct the
56                   fhog_filterbank during each call to detect and therefore speeds up
57                   detection.
58         !*/
59         {}
60 
61         // return the first argument to image_scanner_type::detect()
get_detect_argumentprocessed_weight_vector62         const feature_vector_type& get_detect_argument() const { return w; }
63 
64         feature_vector_type w;
65     };
66 
67 // ----------------------------------------------------------------------------------------
68 
69     template <
70         typename image_scanner_type_
71         >
72     class object_detector
73     {
74     public:
75         typedef image_scanner_type_ image_scanner_type;
76         typedef typename image_scanner_type::feature_vector_type feature_vector_type;
77 
78         object_detector (
79         );
80 
81         object_detector (
82             const object_detector& item
83         );
84 
85         object_detector (
86             const image_scanner_type& scanner_,
87             const test_box_overlap& overlap_tester_,
88             const feature_vector_type& w_
89         );
90 
91         object_detector (
92             const image_scanner_type& scanner_,
93             const test_box_overlap& overlap_tester_,
94             const std::vector<feature_vector_type>& w_
95         );
96 
97         explicit object_detector (
98             const std::vector<object_detector>& detectors
99         );
100 
num_detectors()101         unsigned long num_detectors (
102         ) const { return w.size(); }
103 
104         const feature_vector_type& get_w (
105             unsigned long idx = 0
106         ) const { return w[idx].w; }
107 
108         const processed_weight_vector<image_scanner_type>& get_processed_w (
109             unsigned long idx = 0
110         ) const { return w[idx]; }
111 
112         const test_box_overlap& get_overlap_tester (
113         ) const;
114 
115         const image_scanner_type& get_scanner (
116         ) const;
117 
118         object_detector& operator= (
119             const object_detector& item
120         );
121 
122         template <
123             typename image_type
124             >
125         std::vector<rectangle> operator() (
126             const image_type& img,
127             double adjust_threshold = 0
128         );
129 
130         template <
131             typename image_type
132             >
133         void operator() (
134             const image_type& img,
135             std::vector<std::pair<double, rectangle> >& final_dets,
136             double adjust_threshold = 0
137         );
138 
139         template <
140             typename image_type
141             >
142         void operator() (
143             const image_type& img,
144             std::vector<std::pair<double, full_object_detection> >& final_dets,
145             double adjust_threshold = 0
146         );
147 
148         template <
149             typename image_type
150             >
151         void operator() (
152             const image_type& img,
153             std::vector<full_object_detection>& final_dets,
154             double adjust_threshold = 0
155         );
156 
157         // These typedefs are here for backwards compatibility with previous versions of
158         // dlib.
159         typedef ::dlib::rect_detection rect_detection;
160         typedef ::dlib::full_detection full_detection;
161 
162         template <
163             typename image_type
164             >
165         void operator() (
166             const image_type& img,
167             std::vector<rect_detection>& final_dets,
168             double adjust_threshold = 0
169         );
170 
171         template <
172             typename image_type
173             >
174         void operator() (
175             const image_type& img,
176             std::vector<full_detection>& final_dets,
177             double adjust_threshold = 0
178         );
179 
180         template <typename T>
181         friend void serialize (
182             const object_detector<T>& item,
183             std::ostream& out
184         );
185 
186         template <typename T>
187         friend void deserialize (
188             object_detector<T>& item,
189             std::istream& in
190         );
191 
192     private:
193 
overlaps_any_box(const std::vector<rect_detection> & rects,const dlib::rectangle & rect)194         bool overlaps_any_box (
195             const std::vector<rect_detection>& rects,
196             const dlib::rectangle& rect
197         ) const
198         {
199             for (unsigned long i = 0; i < rects.size(); ++i)
200             {
201                 if (boxes_overlap(rects[i].rect, rect))
202                     return true;
203             }
204             return false;
205         }
206 
207         test_box_overlap boxes_overlap;
208         std::vector<processed_weight_vector<image_scanner_type> > w;
209         image_scanner_type scanner;
210     };
211 
212 // ----------------------------------------------------------------------------------------
213 
214     template <typename T>
serialize(const object_detector<T> & item,std::ostream & out)215     void serialize (
216         const object_detector<T>& item,
217         std::ostream& out
218     )
219     {
220         int version = 2;
221         serialize(version, out);
222 
223         T scanner;
224         scanner.copy_configuration(item.scanner);
225         serialize(scanner, out);
226         serialize(item.boxes_overlap, out);
227         // serialize all the weight vectors
228         serialize(item.w.size(), out);
229         for (unsigned long i = 0; i < item.w.size(); ++i)
230             serialize(item.w[i].w, out);
231     }
232 
233 // ----------------------------------------------------------------------------------------
234 
235     template <typename T>
deserialize(object_detector<T> & item,std::istream & in)236     void deserialize (
237         object_detector<T>& item,
238         std::istream& in
239     )
240     {
241         int version = 0;
242         deserialize(version, in);
243         if (version == 1)
244         {
245             deserialize(item.scanner, in);
246             item.w.resize(1);
247             deserialize(item.w[0].w, in);
248             item.w[0].init(item.scanner);
249             deserialize(item.boxes_overlap, in);
250         }
251         else if (version == 2)
252         {
253             deserialize(item.scanner, in);
254             deserialize(item.boxes_overlap, in);
255             unsigned long num_detectors = 0;
256             deserialize(num_detectors, in);
257             item.w.resize(num_detectors);
258             for (unsigned long i = 0; i < item.w.size(); ++i)
259             {
260                 deserialize(item.w[i].w, in);
261                 item.w[i].init(item.scanner);
262             }
263         }
264         else
265         {
266             throw serialization_error("Unexpected version encountered while deserializing a dlib::object_detector object.");
267         }
268     }
269 
270 // ----------------------------------------------------------------------------------------
271 // ----------------------------------------------------------------------------------------
272 //                      object_detector member functions
273 // ----------------------------------------------------------------------------------------
274 // ----------------------------------------------------------------------------------------
275 
276     template <
277         typename image_scanner_type
278         >
279     object_detector<image_scanner_type>::
object_detector()280     object_detector (
281     )
282     {
283     }
284 
285 // ----------------------------------------------------------------------------------------
286 
287     template <
288         typename image_scanner_type
289         >
290     object_detector<image_scanner_type>::
object_detector(const object_detector & item)291     object_detector (
292         const object_detector& item
293     )
294     {
295         boxes_overlap = item.boxes_overlap;
296         w = item.w;
297         scanner.copy_configuration(item.scanner);
298     }
299 
300 // ----------------------------------------------------------------------------------------
301 
302     template <
303         typename image_scanner_type
304         >
305     object_detector<image_scanner_type>::
object_detector(const image_scanner_type & scanner_,const test_box_overlap & overlap_tester,const feature_vector_type & w_)306     object_detector (
307         const image_scanner_type& scanner_,
308         const test_box_overlap& overlap_tester,
309         const feature_vector_type& w_
310     ) :
311         boxes_overlap(overlap_tester)
312     {
313         // make sure requires clause is not broken
314         DLIB_ASSERT(scanner_.get_num_detection_templates() > 0 &&
315                     w_.size() == scanner_.get_num_dimensions() + 1,
316             "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
317             << "\n\t Invalid inputs were given to this function "
318             << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
319             << "\n\t w_.size():                     " << w_.size()
320             << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions()
321             << "\n\t this: " << this
322             );
323 
324         scanner.copy_configuration(scanner_);
325         w.resize(1);
326         w[0].w = w_;
327         w[0].init(scanner);
328     }
329 
330 // ----------------------------------------------------------------------------------------
331 
332     template <
333         typename image_scanner_type
334         >
335     object_detector<image_scanner_type>::
object_detector(const image_scanner_type & scanner_,const test_box_overlap & overlap_tester,const std::vector<feature_vector_type> & w_)336     object_detector (
337         const image_scanner_type& scanner_,
338         const test_box_overlap& overlap_tester,
339         const std::vector<feature_vector_type>& w_
340     ) :
341         boxes_overlap(overlap_tester)
342     {
343         // make sure requires clause is not broken
344         DLIB_CASSERT(scanner_.get_num_detection_templates() > 0 && w_.size() > 0,
345             "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
346             << "\n\t Invalid inputs were given to this function "
347             << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
348             << "\n\t w_.size():                     " << w_.size()
349             << "\n\t this: " << this
350             );
351 
352         for (unsigned long i = 0; i < w_.size(); ++i)
353         {
354             DLIB_CASSERT(w_[i].size() == scanner_.get_num_dimensions() + 1,
355                 "\t object_detector::object_detector(scanner_,overlap_tester,w_)"
356                 << "\n\t Invalid inputs were given to this function "
357                 << "\n\t scanner_.get_num_detection_templates(): " << scanner_.get_num_detection_templates()
358                 << "\n\t w_["<<i<<"].size():                     " << w_[i].size()
359                 << "\n\t scanner_.get_num_dimensions(): " << scanner_.get_num_dimensions()
360                 << "\n\t this: " << this
361                 );
362         }
363 
364         scanner.copy_configuration(scanner_);
365         w.resize(w_.size());
366         for (unsigned long i = 0; i < w.size(); ++i)
367         {
368             w[i].w = w_[i];
369             w[i].init(scanner);
370         }
371     }
372 
373 // ----------------------------------------------------------------------------------------
374 
375     template <
376         typename image_scanner_type
377         >
378     object_detector<image_scanner_type>::
object_detector(const std::vector<object_detector> & detectors)379     object_detector (
380         const std::vector<object_detector>& detectors
381     )
382     {
383         DLIB_CASSERT(detectors.size() != 0,
384                 "\t object_detector::object_detector(detectors)"
385                 << "\n\t Invalid inputs were given to this function "
386                 << "\n\t this: " << this
387         );
388         std::vector<feature_vector_type> weights;
389         weights.reserve(detectors.size());
390         for (unsigned long i = 0; i < detectors.size(); ++i)
391         {
392             for (unsigned long j = 0; j < detectors[i].num_detectors(); ++j)
393                 weights.push_back(detectors[i].get_w(j));
394         }
395 
396         *this = object_detector(detectors[0].get_scanner(), detectors[0].get_overlap_tester(), weights);
397     }
398 
399 // ----------------------------------------------------------------------------------------
400 
401     template <
402         typename image_scanner_type
403         >
404     object_detector<image_scanner_type>& object_detector<image_scanner_type>::
405     operator= (
406         const object_detector& item
407     )
408     {
409         if (this == &item)
410             return *this;
411 
412         boxes_overlap = item.boxes_overlap;
413         w = item.w;
414         scanner.copy_configuration(item.scanner);
415         return *this;
416     }
417 
418 // ----------------------------------------------------------------------------------------
419 
420     template <
421         typename image_scanner_type
422         >
423     template <
424         typename image_type
425         >
426     void object_detector<image_scanner_type>::
operator()427     operator() (
428         const image_type& img,
429         std::vector<rect_detection>& final_dets,
430         double adjust_threshold
431     )
432     {
433         scanner.load(img);
434         std::vector<std::pair<double, rectangle> > dets;
435         std::vector<rect_detection> dets_accum;
436         for (unsigned long i = 0; i < w.size(); ++i)
437         {
438             const double thresh = w[i].w(scanner.get_num_dimensions());
439             scanner.detect(w[i].get_detect_argument(), dets, thresh + adjust_threshold);
440             for (unsigned long j = 0; j < dets.size(); ++j)
441             {
442                 rect_detection temp;
443                 temp.detection_confidence = dets[j].first-thresh;
444                 temp.weight_index = i;
445                 temp.rect = dets[j].second;
446                 dets_accum.push_back(temp);
447             }
448         }
449 
450         // Do non-max suppression
451         final_dets.clear();
452         if (w.size() > 1)
453             std::sort(dets_accum.rbegin(), dets_accum.rend());
454         for (unsigned long i = 0; i < dets_accum.size(); ++i)
455         {
456             if (overlaps_any_box(final_dets, dets_accum[i].rect))
457                 continue;
458 
459             final_dets.push_back(dets_accum[i]);
460         }
461     }
462 
463 // ----------------------------------------------------------------------------------------
464 
465     template <
466         typename image_scanner_type
467         >
468     template <
469         typename image_type
470         >
471     void object_detector<image_scanner_type>::
operator()472     operator() (
473         const image_type& img,
474         std::vector<full_detection>& final_dets,
475         double adjust_threshold
476     )
477     {
478         std::vector<rect_detection> dets;
479         (*this)(img,dets,adjust_threshold);
480 
481         final_dets.resize(dets.size());
482 
483         // convert all the rectangle detections into full_object_detections.
484         for (unsigned long i = 0; i < dets.size(); ++i)
485         {
486             final_dets[i].detection_confidence = dets[i].detection_confidence;
487             final_dets[i].weight_index = dets[i].weight_index;
488             final_dets[i].rect = scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w);
489         }
490     }
491 
492 // ----------------------------------------------------------------------------------------
493 
494     template <
495         typename image_scanner_type
496         >
497     template <
498         typename image_type
499         >
500     std::vector<rectangle> object_detector<image_scanner_type>::
operator()501     operator() (
502         const image_type& img,
503         double adjust_threshold
504     )
505     {
506         std::vector<rect_detection> dets;
507         (*this)(img,dets,adjust_threshold);
508 
509         std::vector<rectangle> final_dets(dets.size());
510         for (unsigned long i = 0; i < dets.size(); ++i)
511             final_dets[i] = dets[i].rect;
512 
513         return final_dets;
514     }
515 
516 // ----------------------------------------------------------------------------------------
517 
518     template <
519         typename image_scanner_type
520         >
521     template <
522         typename image_type
523         >
524     void object_detector<image_scanner_type>::
operator()525     operator() (
526         const image_type& img,
527         std::vector<std::pair<double, rectangle> >& final_dets,
528         double adjust_threshold
529     )
530     {
531         std::vector<rect_detection> dets;
532         (*this)(img,dets,adjust_threshold);
533 
534         final_dets.resize(dets.size());
535         for (unsigned long i = 0; i < dets.size(); ++i)
536             final_dets[i] = std::make_pair(dets[i].detection_confidence,dets[i].rect);
537     }
538 
539 // ----------------------------------------------------------------------------------------
540 
541     template <
542         typename image_scanner_type
543         >
544     template <
545         typename image_type
546         >
547     void object_detector<image_scanner_type>::
operator()548     operator() (
549         const image_type& img,
550         std::vector<std::pair<double, full_object_detection> >& final_dets,
551         double adjust_threshold
552     )
553     {
554         std::vector<rect_detection> dets;
555         (*this)(img,dets,adjust_threshold);
556 
557         final_dets.clear();
558         final_dets.reserve(dets.size());
559 
560         // convert all the rectangle detections into full_object_detections.
561         for (unsigned long i = 0; i < dets.size(); ++i)
562         {
563             final_dets.push_back(std::make_pair(dets[i].detection_confidence,
564                                                 scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w)));
565         }
566     }
567 
568 // ----------------------------------------------------------------------------------------
569 
570     template <
571         typename image_scanner_type
572         >
573     template <
574         typename image_type
575         >
576     void object_detector<image_scanner_type>::
operator()577     operator() (
578         const image_type& img,
579         std::vector<full_object_detection>& final_dets,
580         double adjust_threshold
581     )
582     {
583         std::vector<rect_detection> dets;
584         (*this)(img,dets,adjust_threshold);
585 
586         final_dets.clear();
587         final_dets.reserve(dets.size());
588 
589         // convert all the rectangle detections into full_object_detections.
590         for (unsigned long i = 0; i < dets.size(); ++i)
591         {
592             final_dets.push_back(scanner.get_full_object_detection(dets[i].rect, w[dets[i].weight_index].w));
593         }
594     }
595 
596 // ----------------------------------------------------------------------------------------
597 
598     template <
599         typename image_scanner_type
600         >
601     const test_box_overlap& object_detector<image_scanner_type>::
get_overlap_tester()602     get_overlap_tester (
603     ) const
604     {
605         return boxes_overlap;
606     }
607 
608 // ----------------------------------------------------------------------------------------
609 
610     template <
611         typename image_scanner_type
612         >
613     const image_scanner_type& object_detector<image_scanner_type>::
get_scanner()614     get_scanner (
615     ) const
616     {
617         return scanner;
618     }
619 
620 // ----------------------------------------------------------------------------------------
621 
622 }
623 
624 #endif // DLIB_OBJECT_DeTECTOR_Hh_
625 
626 
627