1 // Copyright (C) 2011 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_ 4 #define DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_ 5 6 #include "structural_svm_problem_threaded_abstract.h" 7 #include "../algs.h" 8 #include <vector> 9 #include "structural_svm_problem.h" 10 #include "../matrix.h" 11 #include "sparse_vector.h" 12 #include <iostream> 13 #include "../threads.h" 14 #include "../misc_api.h" 15 #include "../statistics.h" 16 17 namespace dlib 18 { 19 20 // ---------------------------------------------------------------------------------------- 21 22 template < 23 typename matrix_type_, 24 typename feature_vector_type_ = matrix_type_ 25 > 26 class structural_svm_problem_threaded : public structural_svm_problem<matrix_type_,feature_vector_type_> 27 { 28 public: 29 30 typedef matrix_type_ matrix_type; 31 typedef typename matrix_type::type scalar_type; 32 typedef feature_vector_type_ feature_vector_type; 33 structural_svm_problem_threaded(unsigned long num_threads)34 explicit structural_svm_problem_threaded ( 35 unsigned long num_threads 36 ) : 37 tp(num_threads), 38 num_iterations_executed(0) 39 {} 40 get_num_threads()41 unsigned long get_num_threads ( 42 ) const { return tp.num_threads_in_pool(); } 43 44 private: 45 46 struct binder 47 { binderbinder48 binder ( 49 const structural_svm_problem_threaded& self_, 50 const matrix_type& w_, 51 matrix_type& subgradient_, 52 scalar_type& total_loss_, 53 bool buffer_subgradients_locally_ 54 ) : self(self_), w(w_), subgradient(subgradient_), total_loss(total_loss_), 55 buffer_subgradients_locally(buffer_subgradients_locally_){} 56 call_oraclebinder57 void call_oracle ( 58 long begin, 59 long end 60 ) 61 { 62 // If we are only going to call the separation oracle once then don't run 63 // the slightly more complex for loop version of this code. Or if we just 64 // don't want to run the complex buffering one. The code later on decides 65 // if we should do the buffering based on how long it takes to execute. We 66 // do this because, when the subgradient is really high dimensional it can 67 // take a lot of time to add them together. So we might want to avoid 68 // doing that. 69 if (end-begin <= 1 || !buffer_subgradients_locally) 70 { 71 scalar_type loss; 72 feature_vector_type ftemp; 73 for (long i = begin; i < end; ++i) 74 { 75 self.separation_oracle_cached(i, w, loss, ftemp); 76 77 auto_mutex lock(self.accum_mutex); 78 total_loss += loss; 79 add_to(subgradient, ftemp); 80 } 81 } 82 else 83 { 84 scalar_type loss = 0; 85 matrix_type faccum(subgradient.size(),1); 86 faccum = 0; 87 88 feature_vector_type ftemp; 89 90 for (long i = begin; i < end; ++i) 91 { 92 scalar_type loss_temp; 93 self.separation_oracle_cached(i, w, loss_temp, ftemp); 94 loss += loss_temp; 95 add_to(faccum, ftemp); 96 } 97 98 auto_mutex lock(self.accum_mutex); 99 total_loss += loss; 100 add_to(subgradient, faccum); 101 } 102 } 103 104 const structural_svm_problem_threaded& self; 105 const matrix_type& w; 106 matrix_type& subgradient; 107 scalar_type& total_loss; 108 bool buffer_subgradients_locally; 109 }; 110 111 call_separation_oracle_on_all_samples(const matrix_type & w,matrix_type & subgradient,scalar_type & total_loss)112 virtual void call_separation_oracle_on_all_samples ( 113 const matrix_type& w, 114 matrix_type& subgradient, 115 scalar_type& total_loss 116 ) const 117 { 118 ++num_iterations_executed; 119 120 const uint64 start_time = ts.get_timestamp(); 121 122 bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean(); 123 124 // every 50 iterations we should try to flip the buffering scheme to see if 125 // doing it the other way might be better. 126 if ((num_iterations_executed%50) == 0) 127 { 128 buffer_subgradients_locally = !buffer_subgradients_locally; 129 } 130 131 binder b(*this, w, subgradient, total_loss, buffer_subgradients_locally); 132 parallel_for_blocked(tp, 0, this->get_num_samples(), b, &binder::call_oracle); 133 134 const uint64 stop_time = ts.get_timestamp(); 135 136 if (buffer_subgradients_locally) 137 with_buffer_time.add(stop_time-start_time); 138 else 139 without_buffer_time.add(stop_time-start_time); 140 141 } 142 143 mutable thread_pool tp; 144 mutable mutex accum_mutex; 145 mutable timestamper ts; 146 mutable running_stats<double> with_buffer_time; 147 mutable running_stats<double> without_buffer_time; 148 mutable unsigned long num_iterations_executed; 149 }; 150 151 // ---------------------------------------------------------------------------------------- 152 153 } 154 155 #endif // DLIB_STRUCTURAL_SVM_PRObLEM_THREADED_Hh_ 156 157 158