1 // The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
2 /*
    This example shows how to train an instance segmentation net using the PASCAL VOC2012
4     dataset.  For an introduction to what segmentation is, see the accompanying header file
5     dnn_instance_segmentation_ex.h.
6 
    Instructions on how to run the example:
8     1. Download the PASCAL VOC2012 data, and untar it somewhere.
9        http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
10     2. Build the dnn_instance_segmentation_train_ex example program.
11     3. Run:
12        ./dnn_instance_segmentation_train_ex /path/to/VOC2012
13     4. Wait while the network is being trained.
14     5. Build the dnn_instance_segmentation_ex example program.
15     6. Run:
16        ./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
17 
18     It would be a good idea to become familiar with dlib's DNN tooling before reading this
19     example.  So you should read dnn_introduction_ex.cpp, dnn_introduction2_ex.cpp,
20     and dnn_semantic_segmentation_train_ex.cpp before reading this example program.
21 */
22 
#include "dnn_instance_segmentation_ex.h"
#include "pascal_voc_2012.h"

#include <dlib/data_io.h>
#include <dlib/image_transforms.h>
#include <dlib/dir_nav.h>

#include <algorithm>
#include <exception>
#include <iostream>
#include <iterator>
#include <map>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
#include <execution>
#endif // __cplusplus >= 201703L
35 
36 using namespace std;
37 using namespace dlib;
38 
39 // ----------------------------------------------------------------------------------------
40 
41 // A single training sample for detection. A mini-batch comprises many of these.
42 struct det_training_sample
43 {
44     matrix<rgb_pixel> input_image;
45     std::vector<dlib::mmod_rect> mmod_rects;
46 };
47 
48 // A single training sample for segmentation. A mini-batch comprises many of these.
49 struct seg_training_sample
50 {
51     matrix<rgb_pixel> input_image;
52     matrix<float> label_image; // The ground-truth label of each pixel. (+1 or -1)
53 };
54 
55 // ----------------------------------------------------------------------------------------
56 
is_instance_pixel(const dlib::rgb_pixel & rgb_label)57 bool is_instance_pixel(const dlib::rgb_pixel& rgb_label)
58 {
59     if (rgb_label == dlib::rgb_pixel(0, 0, 0))
60         return false; // Background
61     if (rgb_label == dlib::rgb_pixel(224, 224, 192))
62         return false; // The cream-colored `void' label is used in border regions and to mask difficult objects
63 
64     return true;
65 }
66 
67 // Provide hash function for dlib::rgb_pixel
68 namespace std {
69     template <>
70     struct hash<dlib::rgb_pixel>
71     {
operator ()std::hash72         std::size_t operator()(const dlib::rgb_pixel& p) const
73         {
74             return (static_cast<uint32_t>(p.red) << 16)
75                  | (static_cast<uint32_t>(p.green) << 8)
76                  | (static_cast<uint32_t>(p.blue));
77         }
78     };
79 }
80 
// One object instance found in a ground-truth image pair.  Field order matters: this
// struct is aggregate-initialized positionally ({ rgb_label, mmod_rect }).
struct truth_instance
{
    dlib::rgb_pixel rgb_label;  // color that identifies this instance in the instance label image
    dlib::mmod_rect mmod_rect;  // bounding box of the instance's pixels; .label holds the class name,
                                // .ignore may be set later to exclude the box from detector training
};
86 
rgb_label_images_to_truth_instances(const dlib::matrix<dlib::rgb_pixel> & instance_label_image,const dlib::matrix<dlib::rgb_pixel> & class_label_image)87 std::vector<truth_instance> rgb_label_images_to_truth_instances(
88     const dlib::matrix<dlib::rgb_pixel>& instance_label_image,
89     const dlib::matrix<dlib::rgb_pixel>& class_label_image
90 )
91 {
92     std::unordered_map<dlib::rgb_pixel, mmod_rect> result_map;
93 
94     DLIB_CASSERT(instance_label_image.nr() == class_label_image.nr());
95     DLIB_CASSERT(instance_label_image.nc() == class_label_image.nc());
96 
97     const auto nr = instance_label_image.nr();
98     const auto nc = instance_label_image.nc();
99 
100     for (int r = 0; r < nr; ++r)
101     {
102         for (int c = 0; c < nc; ++c)
103         {
104             const auto rgb_instance_label = instance_label_image(r, c);
105 
106             if (!is_instance_pixel(rgb_instance_label))
107                 continue;
108 
109             const auto rgb_class_label = class_label_image(r, c);
110             const Voc2012class& voc2012_class = find_voc2012_class(rgb_class_label);
111 
112             const auto i = result_map.find(rgb_instance_label);
113             if (i == result_map.end())
114             {
115                 // Encountered a new instance
116                 result_map[rgb_instance_label] = rectangle(c, r, c, r);
117                 result_map[rgb_instance_label].label = voc2012_class.classlabel;
118             }
119             else
120             {
121                 // Not the first occurrence - update the rect
122                 auto& rect = i->second.rect;
123 
124                 if (c < rect.left())
125                     rect.set_left(c);
126                 else if (c > rect.right())
127                     rect.set_right(c);
128 
129                 if (r > rect.bottom())
130                     rect.set_bottom(r);
131 
132                 DLIB_CASSERT(i->second.label == voc2012_class.classlabel);
133             }
134         }
135     }
136 
137     std::vector<truth_instance> flat_result;
138     flat_result.reserve(result_map.size());
139 
140     for (const auto& i : result_map) {
141         flat_result.push_back(truth_instance{
142             i.first, i.second
143         });
144     }
145 
146     return flat_result;
147 }
148 
149 // ----------------------------------------------------------------------------------------
150 
// A training image together with all of its ground-truth instances.  Field order matters:
// this struct is aggregate-initialized positionally ({ info, truth_instances }).
struct truth_image
{
    image_info info;                              // image / instance-label / class-label file names
    std::vector<truth_instance> truth_instances;
};
156 
extract_mmod_rects(const std::vector<truth_instance> & truth_instances)157 std::vector<mmod_rect> extract_mmod_rects(
158     const std::vector<truth_instance>& truth_instances
159 )
160 {
161     std::vector<mmod_rect> mmod_rects(truth_instances.size());
162 
163     std::transform(
164         truth_instances.begin(),
165         truth_instances.end(),
166         mmod_rects.begin(),
167         [](const truth_instance& truth) { return truth.mmod_rect; }
168     );
169 
170     return mmod_rects;
171 }
172 
extract_mmod_rect_vectors(const std::vector<truth_image> & truth_images)173 std::vector<std::vector<mmod_rect>> extract_mmod_rect_vectors(
174     const std::vector<truth_image>& truth_images
175 )
176 {
177     std::vector<std::vector<mmod_rect>> mmod_rects(truth_images.size());
178 
179     const auto extract_mmod_rects_from_truth_image = [](const truth_image& truth_image)
180     {
181         return extract_mmod_rects(truth_image.truth_instances);
182     };
183 
184     std::transform(
185         truth_images.begin(),
186         truth_images.end(),
187         mmod_rects.begin(),
188         extract_mmod_rects_from_truth_image
189     );
190 
191     return mmod_rects;
192 }
193 
train_detection_network(const std::vector<truth_image> & truth_images,unsigned int det_minibatch_size)194 det_bnet_type train_detection_network(
195     const std::vector<truth_image>& truth_images,
196     unsigned int det_minibatch_size
197 )
198 {
199     const double initial_learning_rate = 0.1;
200     const double weight_decay = 0.0001;
201     const double momentum = 0.9;
202     const double min_detector_window_overlap_iou = 0.65;
203 
204     const int target_size = 70;
205     const int min_target_size = 30;
206 
207     mmod_options options(
208         extract_mmod_rect_vectors(truth_images),
209         target_size, min_target_size,
210         min_detector_window_overlap_iou
211     );
212 
213     options.overlaps_ignore = test_box_overlap(0.5, 0.9);
214 
215     det_bnet_type det_net(options);
216 
217     det_net.subnet().layer_details().set_num_filters(options.detector_windows.size());
218 
219     dlib::pipe<det_training_sample> data(200);
220     auto f = [&data, &truth_images, target_size, min_target_size](time_t seed)
221     {
222         dlib::rand rnd(time(0) + seed);
223         matrix<rgb_pixel> input_image;
224 
225         random_cropper cropper;
226         cropper.set_seed(time(0));
227         cropper.set_chip_dims(350, 350);
228 
229         // Usually you want to give the cropper whatever min sizes you passed to the
230         // mmod_options constructor, or very slightly smaller sizes, which is what we do here.
231         cropper.set_min_object_size(target_size - 2, min_target_size - 2);
232         cropper.set_max_rotation_degrees(2);
233 
234         det_training_sample temp;
235 
236         while (data.is_enabled())
237         {
238             // Pick a random input image.
239             const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
240             const auto& truth_image = truth_images[random_index];
241 
242             // Load the input image.
243             load_image(input_image, truth_image.info.image_filename);
244 
245             // Get a random crop of the input.
246             const auto mmod_rects = extract_mmod_rects(truth_image.truth_instances);
247             cropper(input_image, mmod_rects, temp.input_image, temp.mmod_rects);
248 
249             disturb_colors(temp.input_image, rnd);
250 
251             // Push the result to be used by the trainer.
252             data.enqueue(temp);
253         }
254     };
255     std::thread data_loader1([f]() { f(1); });
256     std::thread data_loader2([f]() { f(2); });
257     std::thread data_loader3([f]() { f(3); });
258     std::thread data_loader4([f]() { f(4); });
259 
260     const auto stop_data_loaders = [&]()
261     {
262         data.disable();
263         data_loader1.join();
264         data_loader2.join();
265         data_loader3.join();
266         data_loader4.join();
267     };
268 
269     dnn_trainer<det_bnet_type> det_trainer(det_net, sgd(weight_decay, momentum));
270 
271     try
272     {
273         det_trainer.be_verbose();
274         det_trainer.set_learning_rate(initial_learning_rate);
275         det_trainer.set_synchronization_file("pascal_voc2012_det_trainer_state_file.dat", std::chrono::minutes(10));
276         det_trainer.set_iterations_without_progress_threshold(5000);
277 
278         // Output training parameters.
279         cout << det_trainer << endl;
280 
281         std::vector<matrix<rgb_pixel>> samples;
282         std::vector<std::vector<mmod_rect>> labels;
283 
284         // The main training loop.  Keep making mini-batches and giving them to the trainer.
285         // We will run until the learning rate becomes small enough.
286         while (det_trainer.get_learning_rate() >= 1e-4)
287         {
288             samples.clear();
289             labels.clear();
290 
291             // make a mini-batch
292             det_training_sample temp;
293             while (samples.size() < det_minibatch_size)
294             {
295                 data.dequeue(temp);
296 
297                 samples.push_back(std::move(temp.input_image));
298                 labels.push_back(std::move(temp.mmod_rects));
299             }
300 
301             det_trainer.train_one_step(samples, labels);
302         }
303     }
304     catch (std::exception&)
305     {
306         stop_data_loaders();
307         throw;
308     }
309 
310     // Training done, tell threads to stop and make sure to wait for them to finish before
311     // moving on.
312     stop_data_loaders();
313 
314     // also wait for threaded processing to stop in the trainer.
315     det_trainer.get_net();
316 
317     det_net.clean();
318 
319     return det_net;
320 }
321 
322 // ----------------------------------------------------------------------------------------
323 
matrix<float> keep_only_current_instance(const matrix<rgb_pixel>& rgb_label_image, const rgb_pixel rgb_label)
/*!
    ensures
        - returns a label image with the same dimensions as rgb_label_image where each
          pixel is
            +1 if its color equals rgb_label (the instance being segmented),
             0 if its color is the cream-colored `void' label (224, 224, 192),
            -1 otherwise (background and every other instance).
!*/
{
    const rgb_pixel void_label(224, 224, 192);

    matrix<float> labels(rgb_label_image.nr(), rgb_label_image.nc());

    for (long row = 0; row < labels.nr(); ++row)
    {
        for (long col = 0; col < labels.nc(); ++col)
        {
            const rgb_pixel& pixel = rgb_label_image(row, col);

            float value = -1;
            if (pixel == rgb_label)
                value = +1;
            else if (pixel == void_label)
                value = 0;

            labels(row, col) = value;
        }
    }

    return labels;
}
347 
train_segmentation_network(const std::vector<truth_image> & truth_images,unsigned int seg_minibatch_size,const std::string & classlabel)348 seg_bnet_type train_segmentation_network(
349     const std::vector<truth_image>& truth_images,
350     unsigned int seg_minibatch_size,
351     const std::string& classlabel
352 )
353 {
354     seg_bnet_type seg_net;
355 
356     const double initial_learning_rate = 0.1;
357     const double weight_decay = 0.0001;
358     const double momentum = 0.9;
359 
360     const std::string synchronization_file_name
361         = "pascal_voc2012_seg_trainer_state_file"
362         + (classlabel.empty() ? "" : ("_" + classlabel))
363         + ".dat";
364 
365     dnn_trainer<seg_bnet_type> seg_trainer(seg_net, sgd(weight_decay, momentum));
366     seg_trainer.be_verbose();
367     seg_trainer.set_learning_rate(initial_learning_rate);
368     seg_trainer.set_synchronization_file(synchronization_file_name, std::chrono::minutes(10));
369     seg_trainer.set_iterations_without_progress_threshold(2000);
370     set_all_bn_running_stats_window_sizes(seg_net, 1000);
371 
372     // Output training parameters.
373     cout << seg_trainer << endl;
374 
375     std::vector<matrix<rgb_pixel>> samples;
376     std::vector<matrix<float>> labels;
377 
378     // Start a bunch of threads that read images from disk and pull out random crops.  It's
379     // important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
380     // thread for this kind of data preparation helps us do that.  Each thread puts the
381     // crops into the data queue.
382     dlib::pipe<seg_training_sample> data(200);
383     auto f = [&data, &truth_images](time_t seed)
384     {
385         dlib::rand rnd(time(0) + seed);
386         matrix<rgb_pixel> input_image;
387         matrix<rgb_pixel> rgb_label_image;
388         matrix<rgb_pixel> rgb_label_chip;
389         seg_training_sample temp;
390         while (data.is_enabled())
391         {
392             // Pick a random input image.
393             const auto random_index = rnd.get_random_32bit_number() % truth_images.size();
394             const auto& truth_image = truth_images[random_index];
395             const auto image_truths = truth_image.truth_instances;
396 
397             if (!image_truths.empty())
398             {
399                 const image_info& info = truth_image.info;
400 
401                 // Load the input image.
402                 load_image(input_image, info.image_filename);
403 
404                 // Load the ground-truth (RGB) instance labels.
405                 load_image(rgb_label_image, info.instance_label_filename);
406 
407                 // Pick a random training instance.
408                 const auto& truth_instance = image_truths[rnd.get_random_32bit_number() % image_truths.size()];
409                 const auto& truth_rect = truth_instance.mmod_rect.rect;
410                 const auto cropping_rect = get_cropping_rect(truth_rect);
411 
412                 // Pick a random crop around the instance.
413                 const auto max_x_translate_amount = static_cast<long>(truth_rect.width() / 10.0);
414                 const auto max_y_translate_amount = static_cast<long>(truth_rect.height() / 10.0);
415 
416                 const auto random_translate = point(
417                     rnd.get_integer_in_range(-max_x_translate_amount, max_x_translate_amount + 1),
418                     rnd.get_integer_in_range(-max_y_translate_amount, max_y_translate_amount + 1)
419                 );
420 
421                 const rectangle random_rect(
422                     cropping_rect.left()   + random_translate.x(),
423                     cropping_rect.top()    + random_translate.y(),
424                     cropping_rect.right()  + random_translate.x(),
425                     cropping_rect.bottom() + random_translate.y()
426                 );
427 
428                 const chip_details chip_details(random_rect, chip_dims(seg_dim, seg_dim));
429 
430                 // Crop the input image.
431                 extract_image_chip(input_image, chip_details, temp.input_image, interpolate_bilinear());
432 
433                 disturb_colors(temp.input_image, rnd);
434 
435                 // Crop the labels correspondingly. However, note that here bilinear
436                 // interpolation would make absolutely no sense - you wouldn't say that
437                 // a bicycle is half-way between an aeroplane and a bird, would you?
438                 extract_image_chip(rgb_label_image, chip_details, rgb_label_chip, interpolate_nearest_neighbor());
439 
440                 // Clear pixels not related to the current instance.
441                 temp.label_image = keep_only_current_instance(rgb_label_chip, truth_instance.rgb_label);
442 
443                 // Push the result to be used by the trainer.
444                 data.enqueue(temp);
445             }
446             else
447             {
448                 // TODO: use background samples as well
449             }
450         }
451     };
452     std::thread data_loader1([f]() { f(1); });
453     std::thread data_loader2([f]() { f(2); });
454     std::thread data_loader3([f]() { f(3); });
455     std::thread data_loader4([f]() { f(4); });
456 
457     const auto stop_data_loaders = [&]()
458     {
459         data.disable();
460         data_loader1.join();
461         data_loader2.join();
462         data_loader3.join();
463         data_loader4.join();
464     };
465 
466     try
467     {
468         // The main training loop.  Keep making mini-batches and giving them to the trainer.
469         // We will run until the learning rate has dropped by a factor of 1e-4.
470         while (seg_trainer.get_learning_rate() >= 1e-4)
471         {
472             samples.clear();
473             labels.clear();
474 
475             // make a mini-batch
476             seg_training_sample temp;
477             while (samples.size() < seg_minibatch_size)
478             {
479                 data.dequeue(temp);
480 
481                 samples.push_back(std::move(temp.input_image));
482                 labels.push_back(std::move(temp.label_image));
483             }
484 
485             seg_trainer.train_one_step(samples, labels);
486         }
487     }
488     catch (std::exception&)
489     {
490         stop_data_loaders();
491         throw;
492     }
493 
494     // Training done, tell threads to stop and make sure to wait for them to finish before
495     // moving on.
496     stop_data_loaders();
497 
498     // also wait for threaded processing to stop in the trainer.
499     seg_trainer.get_net();
500 
501     seg_net.clean();
502 
503     return seg_net;
504 }
505 
506 // ----------------------------------------------------------------------------------------
507 
ignore_overlapped_boxes(std::vector<truth_instance> & truth_instances,const test_box_overlap & overlaps)508 int ignore_overlapped_boxes(
509     std::vector<truth_instance>& truth_instances,
510     const test_box_overlap& overlaps
511 )
512 /*!
513     ensures
514         - Whenever two rectangles in boxes overlap, according to overlaps(), we set the
515           smallest box to ignore.
516         - returns the number of newly ignored boxes.
517 !*/
518 {
519     int num_ignored = 0;
520     for (size_t i = 0, end = truth_instances.size(); i < end; ++i)
521     {
522         auto& box_i = truth_instances[i].mmod_rect;
523         if (box_i.ignore)
524             continue;
525         for (size_t j = i+1; j < end; ++j)
526         {
527             auto& box_j = truth_instances[j].mmod_rect;
528             if (box_j.ignore)
529                 continue;
530             if (overlaps(box_i, box_j))
531             {
532                 ++num_ignored;
533                 if(box_i.rect.area() < box_j.rect.area())
534                     box_i.ignore = true;
535                 else
536                     box_j.ignore = true;
537             }
538         }
539     }
540     return num_ignored;
541 }
542 
load_truth_instances(const image_info & info)543 std::vector<truth_instance> load_truth_instances(const image_info& info)
544 {
545     matrix<rgb_pixel> instance_label_image;
546     matrix<rgb_pixel> class_label_image;
547 
548     load_image(instance_label_image, info.instance_label_filename);
549     load_image(class_label_image, info.class_label_filename);
550 
551     return rgb_label_images_to_truth_instances(instance_label_image, class_label_image);
552 }
553 
load_all_truth_instances(const std::vector<image_info> & listing)554 std::vector<std::vector<truth_instance>> load_all_truth_instances(const std::vector<image_info>& listing)
555 {
556     std::vector<std::vector<truth_instance>> truth_instances(listing.size());
557 
558     std::transform(
559 #if __cplusplus >= 201703L || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703L)
560         std::execution::par,
561 #endif // __cplusplus >= 201703L
562         listing.begin(),
563         listing.end(),
564         truth_instances.begin(),
565         load_truth_instances
566     );
567 
568     return truth_instances;
569 }
570 
571 // ----------------------------------------------------------------------------------------
572 
filter_based_on_classlabel(const std::vector<truth_image> & truth_images,const std::vector<std::string> & desired_classlabels)573 std::vector<truth_image> filter_based_on_classlabel(
574     const std::vector<truth_image>& truth_images,
575     const std::vector<std::string>& desired_classlabels
576 )
577 {
578     std::vector<truth_image> result;
579 
580     const auto represents_desired_class = [&desired_classlabels](const truth_instance& truth_instance) {
581         return std::find(
582             desired_classlabels.begin(),
583             desired_classlabels.end(),
584             truth_instance.mmod_rect.label
585         ) != desired_classlabels.end();
586     };
587 
588     for (const auto& input : truth_images)
589     {
590         const auto has_desired_class = std::any_of(
591             input.truth_instances.begin(),
592             input.truth_instances.end(),
593             represents_desired_class
594         );
595 
596         if (has_desired_class) {
597 
598             // NB: This keeps only MMOD rects belonging to any of the desired classes.
599             //     A reasonable alternative could be to keep all rects, but mark those
600             //     belonging in other classes to be ignored during training.
601             std::vector<truth_instance> temp;
602             std::copy_if(
603                 input.truth_instances.begin(),
604                 input.truth_instances.end(),
605                 std::back_inserter(temp),
606                 represents_desired_class
607             );
608 
609             result.push_back(truth_image{ input.info, temp });
610         }
611     }
612 
613     return result;
614 }
615 
616 // Ignore truth boxes that overlap too much, are too small, or have a large aspect ratio.
ignore_some_truth_boxes(std::vector<truth_image> & truth_images)617 void ignore_some_truth_boxes(std::vector<truth_image>& truth_images)
618 {
619     for (auto& i : truth_images)
620     {
621         auto& truth_instances = i.truth_instances;
622 
623         ignore_overlapped_boxes(truth_instances, test_box_overlap(0.90, 0.95));
624 
625         for (auto& truth : truth_instances)
626         {
627             if (truth.mmod_rect.ignore)
628                 continue;
629 
630             const auto& rect = truth.mmod_rect.rect;
631 
632             constexpr unsigned long min_width  = 35;
633             constexpr unsigned long min_height = 35;
634             if (rect.width() < min_width && rect.height() < min_height)
635             {
636                 truth.mmod_rect.ignore = true;
637                 continue;
638             }
639 
640             constexpr double max_aspect_ratio_width_to_height = 3.0;
641             constexpr double max_aspect_ratio_height_to_width = 1.5;
642             const double aspect_ratio_width_to_height = rect.width() / static_cast<double>(rect.height());
643             const double aspect_ratio_height_to_width = 1.0 / aspect_ratio_width_to_height;
644             const bool is_aspect_ratio_too_large
645                 =  aspect_ratio_width_to_height > max_aspect_ratio_width_to_height
646                 || aspect_ratio_height_to_width > max_aspect_ratio_height_to_width;
647 
648             if (is_aspect_ratio_too_large)
649                 truth.mmod_rect.ignore = true;
650         }
651     }
652 }
653 
654 // Filter images that have no (non-ignored) truth
filter_images_with_no_truth(const std::vector<truth_image> & truth_images)655 std::vector<truth_image> filter_images_with_no_truth(const std::vector<truth_image>& truth_images)
656 {
657     std::vector<truth_image> result;
658 
659     for (const auto& truth_image : truth_images)
660     {
661         const auto ignored = [](const truth_instance& truth) { return truth.mmod_rect.ignore; };
662         const auto& truth_instances = truth_image.truth_instances;
663         if (!std::all_of(truth_instances.begin(), truth_instances.end(), ignored))
664             result.push_back(truth_image);
665     }
666 
667     return result;
668 }
669 
main(int argc,char ** argv)670 int main(int argc, char** argv) try
671 {
672     if (argc < 2)
673     {
674         cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
675         cout << endl;
676         cout << "You call this program like this: " << endl;
677         cout << "./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-1] [class-2] [class-3] ..." << endl;
678         return 1;
679     }
680 
681     cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl;
682 
683     const auto listing = get_pascal_voc2012_train_listing(argv[1]);
684     cout << "images in entire dataset: " << listing.size() << endl;
685     if (listing.size() == 0)
686     {
687         cout << "Didn't find the VOC2012 dataset. " << endl;
688         return 1;
689     }
690 
691     // mini-batches smaller than the default can be used with GPUs having less memory
692     const unsigned int det_minibatch_size = argc >= 3 ? std::stoi(argv[2]) : 35;
693     const unsigned int seg_minibatch_size = argc >= 4 ? std::stoi(argv[3]) : 100;
694     cout << "det mini-batch size: " << det_minibatch_size << endl;
695     cout << "seg mini-batch size: " << seg_minibatch_size << endl;
696 
697     std::vector<std::string> desired_classlabels;
698 
699     for (int arg = 4; arg < argc; ++arg)
700         desired_classlabels.push_back(argv[arg]);
701 
702     if (desired_classlabels.empty())
703     {
704         desired_classlabels.push_back("bicycle");
705         desired_classlabels.push_back("car");
706         desired_classlabels.push_back("cat");
707     }
708 
709     cout << "desired classlabels:";
710     for (const auto& desired_classlabel : desired_classlabels)
711         cout << " " << desired_classlabel;
712     cout << endl;
713 
714     // extract the MMOD rects
715     cout << endl << "Extracting all truth instances...";
716     const auto truth_instances = load_all_truth_instances(listing);
717     cout << " Done!" << endl << endl;
718 
719     DLIB_CASSERT(listing.size() == truth_instances.size());
720 
721     std::vector<truth_image> original_truth_images;
722     for (size_t i = 0, end = listing.size(); i < end; ++i)
723     {
724         original_truth_images.push_back(truth_image{
725             listing[i], truth_instances[i]
726         });
727     }
728 
729     auto truth_images_filtered_by_class = filter_based_on_classlabel(original_truth_images, desired_classlabels);
730 
731     cout << "images in dataset filtered by class: " << truth_images_filtered_by_class.size() << endl;
732 
733     ignore_some_truth_boxes(truth_images_filtered_by_class);
734     const auto truth_images = filter_images_with_no_truth(truth_images_filtered_by_class);
735 
736     cout << "images in dataset after ignoring some truth boxes: " << truth_images.size() << endl;
737 
738     // First train an object detector network (loss_mmod).
739     cout << endl << "Training detector network:" << endl;
740     const auto det_net = train_detection_network(truth_images, det_minibatch_size);
741 
742     // Then train mask predictors (segmentation).
743     std::map<std::string, seg_bnet_type> seg_nets_by_class;
744 
745     // This flag controls if a separate mask predictor is trained for each class.
746     // Note that it would also be possible to train a separate mask predictor for
747     // class groups, each containing somehow similar classes -- for example, one
748     // mask predictor for cars and buses, another for cats and dogs, and so on.
749     constexpr bool separate_seg_net_for_each_class = true;
750 
751     if (separate_seg_net_for_each_class)
752     {
753         for (const auto& classlabel : desired_classlabels)
754         {
755             // Consider only the truth images belonging to this class.
756             const auto class_images = filter_based_on_classlabel(truth_images, { classlabel });
757 
758             cout << endl << "Training segmentation network for class " << classlabel << ":" << endl;
759             seg_nets_by_class[classlabel] = train_segmentation_network(class_images, seg_minibatch_size, classlabel);
760         }
761     }
762     else
763     {
764         cout << "Training a single segmentation network:" << endl;
765         seg_nets_by_class[""] = train_segmentation_network(truth_images, seg_minibatch_size, "");
766     }
767 
768     cout << "Saving networks" << endl;
769     serialize(instance_segmentation_net_filename) << det_net << seg_nets_by_class;
770 }
771 
772 catch(std::exception& e)
773 {
774     cout << e.what() << endl;
775 }
776 
777