// Copyright (C) 2011  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_
#define DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_

#include <memory>
#include <iostream>
#include <vector>

#include "structural_svm_distributed_abstract.h"
#include "structural_svm_problem.h"
#include "../bridge.h"
#include "../misc_api.h"
#include "../statistics.h"
#include "../threads.h"
#include "../pipe.h"
#include "../type_safe_union.h"


namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl
    {

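        // The oracle_response message is what a processing node sends back to the
        // controller: the subgradient and loss accumulated over its local samples,
        // along with how many samples it processed.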
        template <typename matrix_type>
        struct oracle_response
        {
            typedef typename matrix_type::type scalar_type;

            matrix_type subgradient;
            scalar_type loss;
            long num;

            friend void swap (oracle_response& a, oracle_response& b)
            {
                a.subgradient.swap(b.subgradient);
                std::swap(a.loss, b.loss);
                std::swap(a.num, b.num);
            }

            friend void serialize (const oracle_response& item, std::ostream& out)
            {
                serialize(item.subgradient, out);
                dlib::serialize(item.loss, out);
                dlib::serialize(item.num, out);
            }

            friend void deserialize (oracle_response& item, std::istream& in)
            {
                deserialize(item.subgradient, in);
                dlib::deserialize(item.loss, in);
                dlib::deserialize(item.num, in);
            }
        };

    // ----------------------------------------------------------------------------------------

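        // The oracle_request message is broadcast by the controller to every processing
        // node.  It carries the current solution vector along with the solver state flags
        // needed by the cached separation oracle.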
        template <typename matrix_type>
        struct oracle_request
        {
            typedef typename matrix_type::type scalar_type;

            matrix_type current_solution;
            scalar_type saved_current_risk_gap;
            bool skip_cache;
            bool converged;

            friend void swap (oracle_request& a, oracle_request& b)
            {
                a.current_solution.swap(b.current_solution);
                std::swap(a.saved_current_risk_gap, b.saved_current_risk_gap);
                std::swap(a.skip_cache, b.skip_cache);
                std::swap(a.converged, b.converged);
            }

            friend void serialize (const oracle_request& item, std::ostream& out)
            {
                serialize(item.current_solution, out);
                dlib::serialize(item.saved_current_risk_gap, out);
                dlib::serialize(item.skip_cache, out);
                dlib::serialize(item.converged, out);
            }

            friend void deserialize (oracle_request& item, std::istream& in)
            {
                deserialize(item.current_solution, in);
                dlib::deserialize(item.saved_current_risk_gap, in);
                dlib::deserialize(item.skip_cache, in);
                dlib::deserialize(item.converged, in);
            }
        };

    }

// ----------------------------------------------------------------------------------------

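    // A svm_struct_processing_node hosts a local structural_svm_problem and services
    // separation oracle requests sent to it over the network by a
    // svm_struct_controller_node.  A minimal usage sketch (your_problem_type is a
    // hypothetical user-defined subclass of structural_svm_problem, and the port and
    // thread count are illustrative):
    //
    //     your_problem_type problem;
    //     svm_struct_processing_node node(problem, 12345, 4);
    //     // The node now listens on port 12345 and answers oracle requests, using 4
    //     // threads, until it is destructed.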
    class svm_struct_processing_node : noncopyable
    {
    public:

        template <
            typename T,
            typename U
            >
        svm_struct_processing_node (
            const structural_svm_problem<T,U>& problem,
            unsigned short port,
            unsigned short num_threads
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(port != 0 && problem.get_num_samples() != 0 &&
                        problem.get_num_dimensions() != 0,
                "\t svm_struct_processing_node()"
                << "\n\t Invalid arguments were given to this function"
                << "\n\t port: " << port
                << "\n\t problem.get_num_samples():    " << problem.get_num_samples()
                << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions()
                << "\n\t this: " << this
                );

            the_problem.reset(new node_type<T,U>(problem, port, num_threads));
        }

    private:

        struct base
        {
            virtual ~base(){}
        };

        template <
            typename matrix_type,
            typename feature_vector_type
            >
        class node_type : public base, threaded_object
        {
        public:
            typedef typename matrix_type::type scalar_type;

            node_type(
                const structural_svm_problem<matrix_type,feature_vector_type>& prob,
                unsigned short port,
                unsigned long num_threads
            ) : in(3), out(3), problem(prob), tp(num_threads)
            {
                b.reconfigure(listen_on_port(port), receive(in), transmit(out));

                start();
            }

            ~node_type()
            {
                in.disable();
                out.disable();
                wait();
            }

        private:

            void thread()
            {
                using namespace impl;
                tsu_in msg;
                tsu_out temp;

                timestamper ts;
                running_stats<double> with_buffer_time;
                running_stats<double> without_buffer_time;
                unsigned long num_iterations_executed = 0;

                while (in.dequeue(msg))
                {
                    // Initialize the cache and compute psi_true.  Note that psi_true holds
                    // the negated sum of the true joint feature vectors, so once the
                    // separation oracle adds in the predicted feature vectors the
                    // accumulated value is the subgradient over this node's samples.
                    if (cache.size() == 0)
                    {
                        cache.resize(problem.get_num_samples());
                        for (unsigned long i = 0; i < cache.size(); ++i)
                            cache[i].init(&problem,i);

                        psi_true.set_size(problem.get_num_dimensions(),1);
                        psi_true = 0;

                        const unsigned long num = problem.get_num_samples();
                        feature_vector_type ftemp;
                        for (unsigned long i = 0; i < num; ++i)
                        {
                            cache[i].get_truth_joint_feature_vector_cached(ftemp);

                            subtract_from(psi_true, ftemp);
                        }
                    }


                    if (msg.template contains<bridge_status>() &&
                        msg.template get<bridge_status>().is_connected)
                    {
                        temp = problem.get_num_dimensions();
                        out.enqueue(temp);

                    }
                    else if (msg.template contains<oracle_request<matrix_type> >())
                    {
                        ++num_iterations_executed;

                        const oracle_request<matrix_type>& req = msg.template get<oracle_request<matrix_type> >();

                        oracle_response<matrix_type>& data = temp.template get<oracle_response<matrix_type> >();

                        data.subgradient = psi_true;
                        data.loss = 0;

                        data.num = problem.get_num_samples();

                        const uint64 start_time = ts.get_timestamp();

                        // Pick whichever buffering strategy has been fastest so far.
                        bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean();

                        // Every 50 iterations, flip the buffering scheme to check whether
                        // doing it the other way has become faster.
                        if ((num_iterations_executed%50) == 0)
                        {
                            buffer_subgradients_locally = !buffer_subgradients_locally;
                        }

                        binder b(*this, req, data, buffer_subgradients_locally);
                        parallel_for_blocked(tp, 0, data.num, b, &binder::call_oracle);

                        const uint64 stop_time = ts.get_timestamp();
                        if (buffer_subgradients_locally)
                            with_buffer_time.add(stop_time-start_time);
                        else
                            without_buffer_time.add(stop_time-start_time);

                        out.enqueue(temp);
                    }
                }
            }

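            // Small function object handed to parallel_for_blocked().  Each invocation
            // runs the separation oracle over a sub-range of this node's samples and
            // accumulates the results into the shared oracle_response.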
            struct binder
            {
                binder (
                    const node_type& self_,
                    const impl::oracle_request<matrix_type>& req_,
                    impl::oracle_response<matrix_type>& data_,
                    bool buffer_subgradients_locally_
                ) : self(self_), req(req_), data(data_),
                    buffer_subgradients_locally(buffer_subgradients_locally_) {}

                void call_oracle (
                    long begin,
                    long end
                )
                {
                    // If we are only going to call the separation oracle once, or if local
                    // buffering has been disabled, skip the more complex buffered version
                    // of this loop.  Whether to buffer is decided by the caller based on
                    // measured run time, because when the subgradient is very high
                    // dimensional, adding the per-sample contributions together locally can
                    // itself take a lot of time, so sometimes it is better to avoid it.
                    if (end-begin <= 1 || !buffer_subgradients_locally)
                    {
                        scalar_type loss;
                        feature_vector_type ftemp;
                        for (long i = begin; i < end; ++i)
                        {
                            self.cache[i].separation_oracle_cached(req.converged,
                                                                   req.skip_cache,
                                                                   req.saved_current_risk_gap,
                                                                   req.current_solution,
                                                                   loss,
                                                                   ftemp);

                            auto_mutex lock(self.accum_mutex);
                            data.loss += loss;
                            add_to(data.subgradient, ftemp);
                        }
                    }
                    else
                    {
                        // Buffer the loss and subgradient for this sub-range locally and
                        // take the mutex only once to merge them into the shared response.
                        scalar_type loss = 0;
                        matrix_type faccum(data.subgradient.size(),1);
                        faccum = 0;

                        feature_vector_type ftemp;

                        for (long i = begin; i < end; ++i)
                        {
                            scalar_type loss_temp;
                            self.cache[i].separation_oracle_cached(req.converged,
                                                                   req.skip_cache,
                                                                   req.saved_current_risk_gap,
                                                                   req.current_solution,
                                                                   loss_temp,
                                                                   ftemp);
                            loss += loss_temp;
                            add_to(faccum, ftemp);
                        }

                        auto_mutex lock(self.accum_mutex);
                        data.loss += loss;
                        add_to(data.subgradient, faccum);
                    }
                }

                const node_type& self;
                const impl::oracle_request<matrix_type>& req;
                impl::oracle_response<matrix_type>& data;
                bool buffer_subgradients_locally;
            };


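            // Messages arriving over the bridge are either oracle requests from the
            // controller or bridge status updates; outgoing messages are either oracle
            // responses or the problem dimensionality (sent as a long when a controller
            // connects).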
            typedef type_safe_union<impl::oracle_request<matrix_type>, bridge_status> tsu_in;
            typedef type_safe_union<impl::oracle_response<matrix_type>, long> tsu_out;

            pipe<tsu_in> in;
            pipe<tsu_out> out;
            bridge b;

            mutable matrix_type psi_true;
            const structural_svm_problem<matrix_type,feature_vector_type>& problem;
            mutable std::vector<cache_element_structural_svm<structural_svm_problem<matrix_type,feature_vector_type> > > cache;

            mutable thread_pool tp;
            mutex accum_mutex;
        };


        std::unique_ptr<base> the_problem;
    };

// ----------------------------------------------------------------------------------------

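    // The svm_struct_controller_node runs the overall optimization, farming the
    // separation oracle work out to the processing nodes registered with it.  A minimal
    // usage sketch (the addresses, port, and C value are illustrative, and w is assumed
    // to be a dlib column vector such as matrix<double,0,1>):
    //
    //     svm_struct_controller_node controller;
    //     controller.set_c(100);
    //     controller.add_processing_node("192.168.1.10", 12345);
    //     controller.add_processing_node("192.168.1.11", 12345);
    //     matrix<double,0,1> w;
    //     controller(oca(), w);   // run the solver; the learned weights end up in w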
    class svm_struct_controller_node : noncopyable
    {
    public:

        svm_struct_controller_node (
        ) :
            eps(0.001),
            max_iterations(10000),
            cache_based_eps(std::numeric_limits<double>::infinity()),
            verbose(false),
            C(1)
        {}

        double get_cache_based_epsilon (
        ) const
        {
            return cache_based_eps;
        }

        void set_cache_based_epsilon (
            double eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void svm_struct_controller_node::set_cache_based_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_
                << "\n\t this: " << this
                );

            cache_based_eps = eps_;
        }

        void set_epsilon (
            double eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void svm_struct_controller_node::set_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_
                << "\n\t this: " << this
                );

            eps = eps_;
        }

        double get_epsilon (
        ) const { return eps; }

        unsigned long get_max_iterations (
        ) const { return max_iterations; }

        void set_max_iterations (
            unsigned long max_iter
        )
        {
            max_iterations = max_iter;
        }

        void be_verbose (
        )
        {
            verbose = true;
        }

        void be_quiet(
        )
        {
            verbose = false;
        }

        void add_nuclear_norm_regularizer (
            long first_dimension,
            long rows,
            long cols,
            double regularization_strength
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(0 <= first_dimension  &&
                0 <= rows && 0 <= cols &&
                0 < regularization_strength,
                "\t void svm_struct_controller_node::add_nuclear_norm_regularizer()"
                << "\n\t Invalid arguments were given to this function."
                << "\n\t first_dimension:         " << first_dimension
                << "\n\t rows:                    " << rows
                << "\n\t cols:                    " << cols
                << "\n\t regularization_strength: " << regularization_strength
                << "\n\t this: " << this
                );

            impl::nuclear_norm_regularizer temp;
            temp.first_dimension = first_dimension;
            temp.nr = rows;
            temp.nc = cols;
            temp.regularization_strength = regularization_strength;
            nuclear_norm_regularizers.push_back(temp);
        }

        unsigned long num_nuclear_norm_regularizers (
        ) const { return nuclear_norm_regularizers.size(); }

        void clear_nuclear_norm_regularizers (
        ) { nuclear_norm_regularizers.clear(); }


        double get_c (
        ) const { return C; }

        void set_c (
            double C_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C_ > 0,
                "\t void svm_struct_controller_node::set_c()"
                << "\n\t C_ must be greater than 0"
                << "\n\t C_:    " << C_
                << "\n\t this: " << this
                );

            C = C_;
        }

        void add_processing_node (
            const network_address& addr
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(addr.port != 0,
                "\t void svm_struct_controller_node::add_processing_node()"
                << "\n\t Invalid inputs were given to this function"
                << "\n\t addr.host_address:   " << addr.host_address
                << "\n\t addr.port: " << addr.port
                << "\n\t this: " << this
                );

            // check if this address is already registered
            for (unsigned long i = 0; i < nodes.size(); ++i)
            {
                if (nodes[i] == addr)
                {
                    return;
                }
            }

            nodes.push_back(addr);
        }

        void add_processing_node (
            const std::string& ip_or_hostname,
            unsigned short port
        )
        {
            add_processing_node(network_address(ip_or_hostname,port));
        }

        unsigned long get_num_processing_nodes (
        ) const
        {
            return nodes.size();
        }

        void remove_processing_nodes (
        )
        {
            nodes.clear();
        }

        template <typename matrix_type>
        double operator() (
            const oca& solver,
            matrix_type& w
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(get_num_processing_nodes() != 0,
                        "\t double svm_struct_controller_node::operator()"
                        << "\n\t You must add some processing nodes before calling this function."
                        << "\n\t this: " << this
            );

            problem_type<matrix_type> problem(nodes);
            problem.set_cache_based_epsilon(cache_based_eps);
            problem.set_epsilon(eps);
            problem.set_max_iterations(max_iterations);
            if (verbose)
                problem.be_verbose();
            problem.set_c(C);
            for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i)
            {
                problem.add_nuclear_norm_regularizer(
                    nuclear_norm_regularizers[i].first_dimension,
                    nuclear_norm_regularizers[i].nr,
                    nuclear_norm_regularizers[i].nc,
                    nuclear_norm_regularizers[i].regularization_strength);
            }

            return solver(problem, w);
        }

        class invalid_problem : public error
        {
        public:
            invalid_problem(
                const std::string& a
            ): error(a) {}
        };


    private:

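        // This is the problem object actually handed to the oca solver.  It overrides
        // get_risk() so that each risk/subgradient evaluation is farmed out to the
        // remote processing nodes over the bridges and their responses are aggregated.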
        template <typename matrix_type_>
        class problem_type : public structural_svm_problem<matrix_type_>
        {
        public:
            typedef typename matrix_type_::type scalar_type;
            typedef matrix_type_ matrix_type;

            problem_type (
                const std::vector<network_address>& nodes_
            ) :
                nodes(nodes_),
                in(3),
                num_dims(0)
            {

                // initialize all the transmit pipes
                out_pipes.resize(nodes.size());
                for (unsigned long i = 0; i < out_pipes.size(); ++i)
                {
                    out_pipes[i].reset(new pipe<tsu_out>(3));
                }

                // make bridges that connect to all our remote processing nodes
                bridges.resize(nodes.size());
                for (unsigned long i = 0; i < bridges.size(); ++i)
                {
                    bridges[i].reset(new bridge(connect_to(nodes[i]),
                                                receive(in), transmit(*out_pipes[i])));
                }



                // Each remote processing node sends the problem dimensionality when a
                // connection is established.  Collect these and make sure all the nodes
                // agree on what it is.
                tsu_in temp;
                unsigned long responses = 0;
                bool seen_dim = false;
                while (responses < nodes.size())
                {
                    in.dequeue(temp);
                    if (temp.template contains<long>())
                    {
                        ++responses;
                        // if this new dimension doesn't match what we have seen previously
                        if (seen_dim && num_dims != temp.template get<long>())
                        {
                            throw invalid_problem("remote hosts disagree on the number of dimensions!");
                        }
                        seen_dim = true;
                        num_dims = temp.template get<long>();
                    }
                }
            }

            // These functions are only here because structural_svm_problem requires them.
            // Since we override get_risk() they are never called, so their behavior
            // doesn't matter.
            virtual long get_num_samples () const {return 0;}
            virtual void get_truth_joint_feature_vector ( long , matrix_type&  ) const {}
            virtual void separation_oracle ( const long , const matrix_type& , scalar_type& , matrix_type& ) const {}

            virtual long get_num_dimensions (
            ) const
            {
                return num_dims;
            }

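            // Broadcast the current solution to every processing node, then sum up the
            // returned subgradients and losses and divide by the total number of samples,
            // so the result is the risk (and its subgradient) averaged over all samples,
            // plus any nuclear norm regularization terms.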
            virtual void get_risk (
                matrix_type& w,
                scalar_type& risk,
                matrix_type& subgradient
            ) const
            {
                using namespace impl;
                subgradient.set_size(w.size(),1);
                subgradient = 0;

                // send out all the oracle requests
                tsu_out temp_out;
                for (unsigned long i = 0; i < out_pipes.size(); ++i)
                {
                    temp_out.template get<oracle_request<matrix_type> >().current_solution = w;
                    temp_out.template get<oracle_request<matrix_type> >().saved_current_risk_gap = this->saved_current_risk_gap;
                    temp_out.template get<oracle_request<matrix_type> >().skip_cache = this->skip_cache;
                    temp_out.template get<oracle_request<matrix_type> >().converged = this->converged;
                    out_pipes[i]->enqueue(temp_out);
                }

                // collect all the oracle responses
                long num = 0;
                scalar_type total_loss = 0;
                tsu_in temp_in;
                unsigned long responses = 0;
                while (responses < out_pipes.size())
                {
                    in.dequeue(temp_in);
                    if (temp_in.template contains<oracle_response<matrix_type> >())
                    {
                        ++responses;
                        const oracle_response<matrix_type>& data = temp_in.template get<oracle_response<matrix_type> >();
                        subgradient += data.subgradient;
                        total_loss += data.loss;
                        num += data.num;
                    }
                }

                subgradient /= num;
                total_loss /= num;
                risk = total_loss + dot(subgradient,w);

                if (this->nuclear_norm_regularizers.size() != 0)
                {
                    matrix_type grad;
                    double obj;
                    this->compute_nuclear_norm_parts(w, grad, obj);
                    risk += obj;
                    subgradient += grad;
                }
            }

            std::vector<network_address> nodes;

            typedef type_safe_union<impl::oracle_request<matrix_type> > tsu_out;
            typedef type_safe_union<impl::oracle_response<matrix_type>, long> tsu_in;

            std::vector<std::shared_ptr<pipe<tsu_out> > > out_pipes;
            mutable pipe<tsu_in> in;
            std::vector<std::shared_ptr<bridge> > bridges;
            long num_dims;
        };

        std::vector<network_address> nodes;
        double eps;
        unsigned long max_iterations;
        double cache_based_eps;
        bool verbose;
        double C;
        std::vector<impl::nuclear_norm_regularizer> nuclear_norm_regularizers;
    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_
