1 // Copyright (C) 2011 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ 4 #define DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ 5 6 #include <memory> 7 #include <iostream> 8 #include <vector> 9 10 #include "structural_svm_distributed_abstract.h" 11 #include "structural_svm_problem.h" 12 #include "../bridge.h" 13 #include "../misc_api.h" 14 #include "../statistics.h" 15 #include "../threads.h" 16 #include "../pipe.h" 17 #include "../type_safe_union.h" 18 19 20 namespace dlib 21 { 22 23 // ---------------------------------------------------------------------------------------- 24 25 namespace impl 26 { 27 28 template <typename matrix_type> 29 struct oracle_response 30 { 31 typedef typename matrix_type::type scalar_type; 32 33 matrix_type subgradient; 34 scalar_type loss; 35 long num; 36 swaporacle_response37 friend void swap (oracle_response& a, oracle_response& b) 38 { 39 a.subgradient.swap(b.subgradient); 40 std::swap(a.loss, b.loss); 41 std::swap(a.num, b.num); 42 } 43 serializeoracle_response44 friend void serialize (const oracle_response& item, std::ostream& out) 45 { 46 serialize(item.subgradient, out); 47 dlib::serialize(item.loss, out); 48 dlib::serialize(item.num, out); 49 } 50 deserializeoracle_response51 friend void deserialize (oracle_response& item, std::istream& in) 52 { 53 deserialize(item.subgradient, in); 54 dlib::deserialize(item.loss, in); 55 dlib::deserialize(item.num, in); 56 } 57 }; 58 59 // ---------------------------------------------------------------------------------------- 60 61 template <typename matrix_type> 62 struct oracle_request 63 { 64 typedef typename matrix_type::type scalar_type; 65 66 matrix_type current_solution; 67 scalar_type saved_current_risk_gap; 68 bool skip_cache; 69 bool converged; 70 swaporacle_request71 friend void swap (oracle_request& a, oracle_request& b) 72 { 73 a.current_solution.swap(b.current_solution); 74 std::swap(a.saved_current_risk_gap, b.saved_current_risk_gap); 75 std::swap(a.skip_cache, b.skip_cache); 76 std::swap(a.converged, b.converged); 77 } 78 serializeoracle_request79 friend void serialize (const oracle_request& item, std::ostream& out) 80 { 81 serialize(item.current_solution, out); 82 dlib::serialize(item.saved_current_risk_gap, out); 83 dlib::serialize(item.skip_cache, out); 84 dlib::serialize(item.converged, out); 85 } 86 deserializeoracle_request87 friend void deserialize (oracle_request& item, std::istream& in) 88 { 89 deserialize(item.current_solution, in); 90 dlib::deserialize(item.saved_current_risk_gap, in); 91 dlib::deserialize(item.skip_cache, in); 92 dlib::deserialize(item.converged, in); 93 } 94 }; 95 96 } 97 98 // ---------------------------------------------------------------------------------------- 99 100 class svm_struct_processing_node : noncopyable 101 { 102 public: 103 104 template < 105 typename T, 106 typename U 107 > svm_struct_processing_node(const structural_svm_problem<T,U> & problem,unsigned short port,unsigned short num_threads)108 svm_struct_processing_node ( 109 const structural_svm_problem<T,U>& problem, 110 unsigned short port, 111 unsigned short num_threads 112 ) 113 { 114 // make sure requires clause is not broken 115 DLIB_ASSERT(port != 0 && problem.get_num_samples() != 0 && 116 problem.get_num_dimensions() != 0, 117 "\t svm_struct_processing_node()" 118 << "\n\t Invalid arguments were given to this function" 119 << "\n\t port: " << port 120 << "\n\t problem.get_num_samples(): " << problem.get_num_samples() 121 << "\n\t problem.get_num_dimensions(): " << problem.get_num_dimensions() 122 << "\n\t this: " << this 123 ); 124 125 the_problem.reset(new node_type<T,U>(problem, port, num_threads)); 126 } 127 128 private: 129 130 struct base 131 { ~basebase132 virtual ~base(){} 133 }; 134 135 template < 136 typename matrix_type, 137 typename feature_vector_type 138 > 139 class node_type : public base, threaded_object 140 { 141 public: 142 typedef typename matrix_type::type scalar_type; 143 node_type(const structural_svm_problem<matrix_type,feature_vector_type> & prob,unsigned short port,unsigned long num_threads)144 node_type( 145 const structural_svm_problem<matrix_type,feature_vector_type>& prob, 146 unsigned short port, 147 unsigned long num_threads 148 ) : in(3),out(3), problem(prob), tp(num_threads) 149 { 150 b.reconfigure(listen_on_port(port), receive(in), transmit(out)); 151 152 start(); 153 } 154 ~node_type()155 ~node_type() 156 { 157 in.disable(); 158 out.disable(); 159 wait(); 160 } 161 162 private: 163 thread()164 void thread() 165 { 166 using namespace impl; 167 tsu_in msg; 168 tsu_out temp; 169 170 timestamper ts; 171 running_stats<double> with_buffer_time; 172 running_stats<double> without_buffer_time; 173 unsigned long num_iterations_executed = 0; 174 175 while (in.dequeue(msg)) 176 { 177 // initialize the cache and compute psi_true. 178 if (cache.size() == 0) 179 { 180 cache.resize(problem.get_num_samples()); 181 for (unsigned long i = 0; i < cache.size(); ++i) 182 cache[i].init(&problem,i); 183 184 psi_true.set_size(problem.get_num_dimensions(),1); 185 psi_true = 0; 186 187 const unsigned long num = problem.get_num_samples(); 188 feature_vector_type ftemp; 189 for (unsigned long i = 0; i < num; ++i) 190 { 191 cache[i].get_truth_joint_feature_vector_cached(ftemp); 192 193 subtract_from(psi_true, ftemp); 194 } 195 } 196 197 198 if (msg.template contains<bridge_status>() && 199 msg.template get<bridge_status>().is_connected) 200 { 201 temp = problem.get_num_dimensions(); 202 out.enqueue(temp); 203 204 } 205 else if (msg.template contains<oracle_request<matrix_type> >()) 206 { 207 ++num_iterations_executed; 208 209 const oracle_request<matrix_type>& req = msg.template get<oracle_request<matrix_type> >(); 210 211 oracle_response<matrix_type>& data = temp.template get<oracle_response<matrix_type> >(); 212 213 data.subgradient = psi_true; 214 data.loss = 0; 215 216 data.num = problem.get_num_samples(); 217 218 const uint64 start_time = ts.get_timestamp(); 219 220 // pick fastest buffering strategy 221 bool buffer_subgradients_locally = with_buffer_time.mean() < without_buffer_time.mean(); 222 223 // every 50 iterations we should try to flip the buffering scheme to see if 224 // doing it the other way might be better. 225 if ((num_iterations_executed%50) == 0) 226 { 227 buffer_subgradients_locally = !buffer_subgradients_locally; 228 } 229 230 binder b(*this, req, data, buffer_subgradients_locally); 231 parallel_for_blocked(tp, 0, data.num, b, &binder::call_oracle); 232 233 const uint64 stop_time = ts.get_timestamp(); 234 if (buffer_subgradients_locally) 235 with_buffer_time.add(stop_time-start_time); 236 else 237 without_buffer_time.add(stop_time-start_time); 238 239 out.enqueue(temp); 240 } 241 } 242 } 243 244 struct binder 245 { binderbinder246 binder ( 247 const node_type& self_, 248 const impl::oracle_request<matrix_type>& req_, 249 impl::oracle_response<matrix_type>& data_, 250 bool buffer_subgradients_locally_ 251 ) : self(self_), req(req_), data(data_), 252 buffer_subgradients_locally(buffer_subgradients_locally_) {} 253 call_oraclebinder254 void call_oracle ( 255 long begin, 256 long end 257 ) 258 { 259 // If we are only going to call the separation oracle once then don't 260 // run the slightly more complex for loop version of this code. Or if 261 // we just don't want to run the complex buffering one. The code later 262 // on decides if we should do the buffering based on how long it takes 263 // to execute. We do this because, when the subgradient is really high 264 // dimensional it can take a lot of time to add them together. So we 265 // might want to avoid doing that. 266 if (end-begin <= 1 || !buffer_subgradients_locally) 267 { 268 scalar_type loss; 269 feature_vector_type ftemp; 270 for (long i = begin; i < end; ++i) 271 { 272 self.cache[i].separation_oracle_cached(req.converged, 273 req.skip_cache, 274 req.saved_current_risk_gap, 275 req.current_solution, 276 loss, 277 ftemp); 278 279 auto_mutex lock(self.accum_mutex); 280 data.loss += loss; 281 add_to(data.subgradient, ftemp); 282 } 283 } 284 else 285 { 286 scalar_type loss = 0; 287 matrix_type faccum(data.subgradient.size(),1); 288 faccum = 0; 289 290 feature_vector_type ftemp; 291 292 for (long i = begin; i < end; ++i) 293 { 294 scalar_type loss_temp; 295 self.cache[i].separation_oracle_cached(req.converged, 296 req.skip_cache, 297 req.saved_current_risk_gap, 298 req.current_solution, 299 loss_temp, 300 ftemp); 301 loss += loss_temp; 302 add_to(faccum, ftemp); 303 } 304 305 auto_mutex lock(self.accum_mutex); 306 data.loss += loss; 307 add_to(data.subgradient, faccum); 308 } 309 } 310 311 const node_type& self; 312 const impl::oracle_request<matrix_type>& req; 313 impl::oracle_response<matrix_type>& data; 314 bool buffer_subgradients_locally; 315 }; 316 317 318 319 typedef type_safe_union<impl::oracle_request<matrix_type>, bridge_status> tsu_in; 320 typedef type_safe_union<impl::oracle_response<matrix_type> , long> tsu_out; 321 322 pipe<tsu_in> in; 323 pipe<tsu_out> out; 324 bridge b; 325 326 mutable matrix_type psi_true; 327 const structural_svm_problem<matrix_type,feature_vector_type>& problem; 328 mutable std::vector<cache_element_structural_svm<structural_svm_problem<matrix_type,feature_vector_type> > > cache; 329 330 mutable thread_pool tp; 331 mutex accum_mutex; 332 }; 333 334 335 std::unique_ptr<base> the_problem; 336 }; 337 338 // ---------------------------------------------------------------------------------------- 339 340 class svm_struct_controller_node : noncopyable 341 { 342 public: 343 svm_struct_controller_node()344 svm_struct_controller_node ( 345 ) : 346 eps(0.001), 347 max_iterations(10000), 348 cache_based_eps(std::numeric_limits<double>::infinity()), 349 verbose(false), 350 C(1) 351 {} 352 get_cache_based_epsilon()353 double get_cache_based_epsilon ( 354 ) const 355 { 356 return cache_based_eps; 357 } 358 set_cache_based_epsilon(double eps_)359 void set_cache_based_epsilon ( 360 double eps_ 361 ) 362 { 363 // make sure requires clause is not broken 364 DLIB_ASSERT(eps_ > 0, 365 "\t void svm_struct_controller_node::set_cache_based_epsilon()" 366 << "\n\t eps_ must be greater than 0" 367 << "\n\t eps_: " << eps_ 368 << "\n\t this: " << this 369 ); 370 371 cache_based_eps = eps_; 372 } 373 set_epsilon(double eps_)374 void set_epsilon ( 375 double eps_ 376 ) 377 { 378 // make sure requires clause is not broken 379 DLIB_ASSERT(eps_ > 0, 380 "\t void svm_struct_controller_node::set_epsilon()" 381 << "\n\t eps_ must be greater than 0" 382 << "\n\t eps_: " << eps_ 383 << "\n\t this: " << this 384 ); 385 386 eps = eps_; 387 } 388 get_epsilon()389 double get_epsilon ( 390 ) const { return eps; } 391 get_max_iterations()392 unsigned long get_max_iterations ( 393 ) const { return max_iterations; } 394 set_max_iterations(unsigned long max_iter)395 void set_max_iterations ( 396 unsigned long max_iter 397 ) 398 { 399 max_iterations = max_iter; 400 } 401 be_verbose()402 void be_verbose ( 403 ) 404 { 405 verbose = true; 406 } 407 be_quiet()408 void be_quiet( 409 ) 410 { 411 verbose = false; 412 } 413 add_nuclear_norm_regularizer(long first_dimension,long rows,long cols,double regularization_strength)414 void add_nuclear_norm_regularizer ( 415 long first_dimension, 416 long rows, 417 long cols, 418 double regularization_strength 419 ) 420 { 421 // make sure requires clause is not broken 422 DLIB_ASSERT(0 <= first_dimension && 423 0 <= rows && 0 <= cols && 424 0 < regularization_strength, 425 "\t void svm_struct_controller_node::add_nuclear_norm_regularizer()" 426 << "\n\t Invalid arguments were given to this function." 427 << "\n\t first_dimension: " << first_dimension 428 << "\n\t rows: " << rows 429 << "\n\t cols: " << cols 430 << "\n\t regularization_strength: " << regularization_strength 431 << "\n\t this: " << this 432 ); 433 434 impl::nuclear_norm_regularizer temp; 435 temp.first_dimension = first_dimension; 436 temp.nr = rows; 437 temp.nc = cols; 438 temp.regularization_strength = regularization_strength; 439 nuclear_norm_regularizers.push_back(temp); 440 } 441 num_nuclear_norm_regularizers()442 unsigned long num_nuclear_norm_regularizers ( 443 ) const { return nuclear_norm_regularizers.size(); } 444 clear_nuclear_norm_regularizers()445 void clear_nuclear_norm_regularizers ( 446 ) { nuclear_norm_regularizers.clear(); } 447 448 get_c()449 double get_c ( 450 ) const { return C; } 451 set_c(double C_)452 void set_c ( 453 double C_ 454 ) 455 { 456 // make sure requires clause is not broken 457 DLIB_ASSERT(C_ > 0, 458 "\t void svm_struct_controller_node::set_c()" 459 << "\n\t C_ must be greater than 0" 460 << "\n\t C_: " << C_ 461 << "\n\t this: " << this 462 ); 463 464 C = C_; 465 } 466 add_processing_node(const network_address & addr)467 void add_processing_node ( 468 const network_address& addr 469 ) 470 { 471 // make sure requires clause is not broken 472 DLIB_ASSERT(addr.port != 0, 473 "\t void svm_struct_controller_node::add_processing_node()" 474 << "\n\t Invalid inputs were given to this function" 475 << "\n\t addr.host_address: " << addr.host_address 476 << "\n\t addr.port: " << addr.port 477 << "\n\t this: " << this 478 ); 479 480 // check if this address is already registered 481 for (unsigned long i = 0; i < nodes.size(); ++i) 482 { 483 if (nodes[i] == addr) 484 { 485 return; 486 } 487 } 488 489 nodes.push_back(addr); 490 } 491 add_processing_node(const std::string & ip_or_hostname,unsigned short port)492 void add_processing_node ( 493 const std::string& ip_or_hostname, 494 unsigned short port 495 ) 496 { 497 add_processing_node(network_address(ip_or_hostname,port)); 498 } 499 get_num_processing_nodes()500 unsigned long get_num_processing_nodes ( 501 ) const 502 { 503 return nodes.size(); 504 } 505 remove_processing_nodes()506 void remove_processing_nodes ( 507 ) 508 { 509 nodes.clear(); 510 } 511 512 template <typename matrix_type> operator()513 double operator() ( 514 const oca& solver, 515 matrix_type& w 516 ) const 517 { 518 // make sure requires clause is not broken 519 DLIB_ASSERT(get_num_processing_nodes() != 0, 520 "\t double svm_struct_controller_node::operator()" 521 << "\n\t You must add some processing nodes before calling this function." 522 << "\n\t this: " << this 523 ); 524 525 problem_type<matrix_type> problem(nodes); 526 problem.set_cache_based_epsilon(cache_based_eps); 527 problem.set_epsilon(eps); 528 problem.set_max_iterations(max_iterations); 529 if (verbose) 530 problem.be_verbose(); 531 problem.set_c(C); 532 for (unsigned long i = 0; i < nuclear_norm_regularizers.size(); ++i) 533 { 534 problem.add_nuclear_norm_regularizer( 535 nuclear_norm_regularizers[i].first_dimension, 536 nuclear_norm_regularizers[i].nr, 537 nuclear_norm_regularizers[i].nc, 538 nuclear_norm_regularizers[i].regularization_strength); 539 } 540 541 return solver(problem, w); 542 } 543 544 class invalid_problem : public error 545 { 546 public: invalid_problem(const std::string & a)547 invalid_problem( 548 const std::string& a 549 ): error(a) {} 550 }; 551 552 553 private: 554 555 template <typename matrix_type_> 556 class problem_type : public structural_svm_problem<matrix_type_> 557 { 558 public: 559 typedef typename matrix_type_::type scalar_type; 560 typedef matrix_type_ matrix_type; 561 problem_type(const std::vector<network_address> & nodes_)562 problem_type ( 563 const std::vector<network_address>& nodes_ 564 ) : 565 nodes(nodes_), 566 in(3), 567 num_dims(0) 568 { 569 570 // initialize all the transmit pipes 571 out_pipes.resize(nodes.size()); 572 for (unsigned long i = 0; i < out_pipes.size(); ++i) 573 { 574 out_pipes[i].reset(new pipe<tsu_out>(3)); 575 } 576 577 // make bridges that connect to all our remote processing nodes 578 bridges.resize(nodes.size()); 579 for (unsigned long i = 0; i< bridges.size(); ++i) 580 { 581 bridges[i].reset(new bridge(connect_to(nodes[i]), 582 receive(in), transmit(*out_pipes[i]))); 583 } 584 585 586 587 // The remote processing nodes are supposed to all send the problem dimensionality 588 // upon connection. So get that and make sure everyone agrees on what it's supposed to be. 589 tsu_in temp; 590 unsigned long responses = 0; 591 bool seen_dim = false; 592 while (responses < nodes.size()) 593 { 594 in.dequeue(temp); 595 if (temp.template contains<long>()) 596 { 597 ++responses; 598 // if this new dimension doesn't match what we have seen previously 599 if (seen_dim && num_dims != temp.template get<long>()) 600 { 601 throw invalid_problem("remote hosts disagree on the number of dimensions!"); 602 } 603 seen_dim = true; 604 num_dims = temp.template get<long>(); 605 } 606 } 607 } 608 609 // These functions are just here because the structural_svm_problem requires 610 // them, but since we are overloading get_risk() they are never called so they 611 // don't matter. get_num_samples()612 virtual long get_num_samples () const {return 0;} get_truth_joint_feature_vector(long,matrix_type &)613 virtual void get_truth_joint_feature_vector ( long , matrix_type& ) const {} separation_oracle(const long,const matrix_type &,scalar_type &,matrix_type &)614 virtual void separation_oracle ( const long , const matrix_type& , scalar_type& , matrix_type& ) const {} 615 get_num_dimensions()616 virtual long get_num_dimensions ( 617 ) const 618 { 619 return num_dims; 620 } 621 get_risk(matrix_type & w,scalar_type & risk,matrix_type & subgradient)622 virtual void get_risk ( 623 matrix_type& w, 624 scalar_type& risk, 625 matrix_type& subgradient 626 ) const 627 { 628 using namespace impl; 629 subgradient.set_size(w.size(),1); 630 subgradient = 0; 631 632 // send out all the oracle requests 633 tsu_out temp_out; 634 for (unsigned long i = 0; i < out_pipes.size(); ++i) 635 { 636 temp_out.template get<oracle_request<matrix_type> >().current_solution = w; 637 temp_out.template get<oracle_request<matrix_type> >().saved_current_risk_gap = this->saved_current_risk_gap; 638 temp_out.template get<oracle_request<matrix_type> >().skip_cache = this->skip_cache; 639 temp_out.template get<oracle_request<matrix_type> >().converged = this->converged; 640 out_pipes[i]->enqueue(temp_out); 641 } 642 643 // collect all the oracle responses 644 long num = 0; 645 scalar_type total_loss = 0; 646 tsu_in temp_in; 647 unsigned long responses = 0; 648 while (responses < out_pipes.size()) 649 { 650 in.dequeue(temp_in); 651 if (temp_in.template contains<oracle_response<matrix_type> >()) 652 { 653 ++responses; 654 const oracle_response<matrix_type>& data = temp_in.template get<oracle_response<matrix_type> >(); 655 subgradient += data.subgradient; 656 total_loss += data.loss; 657 num += data.num; 658 } 659 } 660 661 subgradient /= num; 662 total_loss /= num; 663 risk = total_loss + dot(subgradient,w); 664 665 if (this->nuclear_norm_regularizers.size() != 0) 666 { 667 matrix_type grad; 668 double obj; 669 this->compute_nuclear_norm_parts(w, grad, obj); 670 risk += obj; 671 subgradient += grad; 672 } 673 } 674 675 std::vector<network_address> nodes; 676 677 typedef type_safe_union<impl::oracle_request<matrix_type> > tsu_out; 678 typedef type_safe_union<impl::oracle_response<matrix_type>, long> tsu_in; 679 680 std::vector<std::shared_ptr<pipe<tsu_out> > > out_pipes; 681 mutable pipe<tsu_in> in; 682 std::vector<std::shared_ptr<bridge> > bridges; 683 long num_dims; 684 }; 685 686 std::vector<network_address> nodes; 687 double eps; 688 unsigned long max_iterations; 689 double cache_based_eps; 690 bool verbose; 691 double C; 692 std::vector<impl::nuclear_norm_regularizer> nuclear_norm_regularizers; 693 }; 694 695 // ---------------------------------------------------------------------------------------- 696 697 } 698 699 #endif // DLIB_STRUCTURAL_SVM_DISTRIBUTeD_Hh_ 700 701