1 // Copyright (C) 2011  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_
4 #define DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_
5 
6 #include "structural_svm_problem_threaded_abstract.h"
7 #include "../algs.h"
8 #include <vector>
9 #include "structural_svm_problem.h"
10 #include "../matrix.h"
11 #include "sparse_vector.h"
12 #include <iostream>
13 #include "../threads.h"
14 #include "../misc_api.h"
15 #include "../statistics.h"
16 
17 namespace dlib
18 {
19 
20 // ----------------------------------------------------------------------------------------
21 
22     template <
23         typename matrix_type_,
24         typename feature_vector_type_ = matrix_type_
25         >
26     class structural_svm_problem_threaded : public structural_svm_problem<matrix_type_,feature_vector_type_>
27     {
28     public:
29 
30         typedef matrix_type_ matrix_type;
31         typedef typename matrix_type::type scalar_type;
32         typedef feature_vector_type_ feature_vector_type;
33 
structural_svm_problem_threaded(unsigned long num_threads)34         explicit structural_svm_problem_threaded (
35             unsigned long num_threads
36         ) :
37             tp(num_threads),
38             num_iterations_executed(0)
39         {}
40 
get_num_threads()41         unsigned long get_num_threads (
42         ) const { return tp.num_threads_in_pool(); }
43 
44     private:
45 
46         struct binder
47         {
binderbinder48             binder (
49                 const structural_svm_problem_threaded& self_,
50                 const matrix_type& w_,
51                 matrix_type& subgradient_,
52                 scalar_type& total_loss_,
53                 bool buffer_subgradients_locally_
54             ) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_),
55                 buffer_subgradients_locally(buffer_subgradients_locally_){}
56 
call_oraclebinder57             void call_oracle (
58                 long begin,
59                 long end
60             )
61             {
62                 // If we are only going to call the separation oracle once then don't run
63                 // the slightly more complex for loop version of this code.  Or if we just
64                 // don't want to run the complex buffering one.  The code later on decides
65                 // if we should do the buffering based on how long it takes to execute.  We
66                 // do this because, when the subgradient is really high dimensional it can
67                 // take a lot of time to add them together.  So we might want to avoid
68                 // doing that.
69                 if (end-begin <= 1 || !buffer_subgradients_locally)
70                 {
71                     scalar_type loss;
72                     feature_vector_type ftemp;
73                     for (long i = begin; i < end; ++i)
74                     {
75                         self.separation_oracle_cached(i, w, loss, ftemp);
76 
77                         auto_mutex lock(self.accum_mutex);
78                         total_loss += loss;
79                         add_to(subgradient, ftemp);
80                     }
81                 }
82                 else
83                 {
84                     scalar_type loss = 0;
85                     matrix_type faccum(subgradient.size(),1);
86                     faccum = 0;
87 
88                     feature_vector_type ftemp;
89 
90                     for (long i = begin; i < end; ++i)
91                     {
92                         scalar_type loss_temp;
93                         self.separation_oracle_cached(i, w, loss_temp, ftemp);
94                         loss += loss_temp;
95                         add_to(faccum, ftemp);
96                     }
97 
98                     auto_mutex lock(self.accum_mutex);
99                     total_loss += loss;
100                     add_to(subgradient, faccum);
101                 }
102             }
103 
104             const structural_svm_problem_threaded& self;
105             const matrix_type& w;
106             matrix_type& subgradient;
107             scalar_type& total_loss;
108             bool buffer_subgradients_locally;
109         };
110 
111 
call_separation_oracle_on_all_samples(const matrix_type & w,matrix_type & subgradient,scalar_type & total_loss)112         virtual void call_separation_oracle_on_all_samples (
113             const matrix_type& w,
114             matrix_type& subgradient,
115             scalar_type& total_loss
116         ) const
117         {
118             ++num_iterations_executed;
119 
120             const uint64 start_time = ts.get_timestamp();
121 
122             bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean();
123 
124             // every 50 iterations we should try to flip the buffering scheme to see if
125             // doing it the other way might be better.
126             if ((num_iterations_executed%50) == 0)
127             {
128                 buffer_subgradients_locally = !buffer_subgradients_locally;
129             }
130 
131             binder b(*this, w, subgradient, total_loss, buffer_subgradients_locally);
132             parallel_for_blocked(tp, 0, this->get_num_samples(), b, &binder::call_oracle);
133 
134             const uint64 stop_time = ts.get_timestamp();
135 
136             if (buffer_subgradients_locally)
137                 with_buffer_time.add(stop_time-start_time);
138             else
139                 without_buffer_time.add(stop_time-start_time);
140 
141         }
142 
143         mutable thread_pool tp;
144         mutable mutex accum_mutex;
145         mutable timestamper ts;
146         mutable running_stats<double> with_buffer_time;
147         mutable running_stats<double> without_buffer_time;
148         mutable unsigned long num_iterations_executed;
149     };
150 
151 // ----------------------------------------------------------------------------------------
152 
153 }
154 
155 #endif // DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_
156 
157 
158