1 // Copyright (C) 2012 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_ 4 #define DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_ 5 6 7 #include "structural_svm_graph_labeling_problem_abstract.h" 8 #include "../graph_cuts.h" 9 #include "../matrix.h" 10 #include "../array.h" 11 #include <vector> 12 #include <iterator> 13 #include "structural_svm_problem_threaded.h" 14 #include "../graph.h" 15 #include "sparse_vector.h" 16 #include <sstream> 17 18 // ---------------------------------------------------------------------------------------- 19 20 namespace dlib 21 { 22 23 // ---------------------------------------------------------------------------------------- 24 25 template < 26 typename graph_type 27 > is_graph_labeling_problem(const dlib::array<graph_type> & samples,const std::vector<std::vector<bool>> & labels,std::string & reason_for_failure)28 bool is_graph_labeling_problem ( 29 const dlib::array<graph_type>& samples, 30 const std::vector<std::vector<bool> >& labels, 31 std::string& reason_for_failure 32 ) 33 { 34 typedef typename graph_type::type node_vector_type; 35 typedef typename graph_type::edge_type edge_vector_type; 36 // The graph must use all dense vectors or all sparse vectors. It can't mix the two types together. 37 COMPILE_TIME_ASSERT( (is_matrix<node_vector_type>::value && is_matrix<edge_vector_type>::value) || 38 (!is_matrix<node_vector_type>::value && !is_matrix<edge_vector_type>::value)); 39 40 41 std::ostringstream sout; 42 reason_for_failure.clear(); 43 44 if (!is_learning_problem(samples, labels)) 45 { 46 reason_for_failure = "is_learning_problem(samples, labels) returned false."; 47 return false; 48 } 49 50 const bool ismat = is_matrix<typename graph_type::type>::value; 51 52 // these are -1 until assigned with a value 53 long node_dims = -1; 54 long edge_dims = -1; 55 56 for (unsigned long i = 0; i < samples.size(); ++i) 57 { 58 if (samples[i].number_of_nodes() != labels[i].size()) 59 { 60 sout << "samples["<<i<<"].number_of_nodes() doesn't match labels["<<i<<"].size()."; 61 reason_for_failure = sout.str(); 62 return false; 63 } 64 if (graph_contains_length_one_cycle(samples[i])) 65 { 66 sout << "graph_contains_length_one_cycle(samples["<<i<<"]) returned true."; 67 reason_for_failure = sout.str(); 68 return false; 69 } 70 71 for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j) 72 { 73 if (ismat && samples[i].node(j).data.size() == 0) 74 { 75 sout << "A graph contains an empty vector at node: samples["<<i<<"].node("<<j<<").data."; 76 reason_for_failure = sout.str(); 77 return false; 78 } 79 80 if (ismat && node_dims == -1) 81 node_dims = samples[i].node(j).data.size(); 82 // all nodes must have vectors of the same size. 83 if (ismat && (long)samples[i].node(j).data.size() != node_dims) 84 { 85 sout << "Not all node vectors in samples["<<i<<"] are the same dimension."; 86 reason_for_failure = sout.str(); 87 return false; 88 } 89 90 for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n) 91 { 92 if (ismat && samples[i].node(j).edge(n).size() == 0) 93 { 94 sout << "A graph contains an empty vector at edge: samples["<<i<<"].node("<<j<<").edge("<<n<<")."; 95 reason_for_failure = sout.str(); 96 return false; 97 } 98 if (min(samples[i].node(j).edge(n)) < 0) 99 { 100 sout << "A graph contains negative values on an edge vector at: samples["<<i<<"].node("<<j<<").edge("<<n<<")."; 101 reason_for_failure = sout.str(); 102 return false; 103 } 104 105 if (ismat && edge_dims == -1) 106 edge_dims = samples[i].node(j).edge(n).size(); 107 // all edges must have vectors of the same size. 108 if (ismat && (long)samples[i].node(j).edge(n).size() != edge_dims) 109 { 110 sout << "Not all edge vectors in samples["<<i<<"] are the same dimension."; 111 reason_for_failure = sout.str(); 112 return false; 113 } 114 } 115 } 116 } 117 118 return true; 119 } 120 121 template < 122 typename graph_type 123 > is_graph_labeling_problem(const dlib::array<graph_type> & samples,const std::vector<std::vector<bool>> & labels)124 bool is_graph_labeling_problem ( 125 const dlib::array<graph_type>& samples, 126 const std::vector<std::vector<bool> >& labels 127 ) 128 { 129 std::string reason_for_failure; 130 return is_graph_labeling_problem(samples, labels, reason_for_failure); 131 } 132 133 // ---------------------------------------------------------------------------------------- 134 135 template < 136 typename T, 137 typename U 138 > sizes_match(const std::vector<std::vector<T>> & lhs,const std::vector<std::vector<U>> & rhs)139 bool sizes_match ( 140 const std::vector<std::vector<T> >& lhs, 141 const std::vector<std::vector<U> >& rhs 142 ) 143 { 144 if (lhs.size() != rhs.size()) 145 return false; 146 147 for (unsigned long i = 0; i < lhs.size(); ++i) 148 { 149 if (lhs[i].size() != rhs[i].size()) 150 return false; 151 } 152 153 return true; 154 } 155 156 // ---------------------------------------------------------------------------------------- 157 all_values_are_nonnegative(const std::vector<std::vector<double>> & x)158 inline bool all_values_are_nonnegative ( 159 const std::vector<std::vector<double> >& x 160 ) 161 { 162 for (unsigned long i = 0; i < x.size(); ++i) 163 { 164 for (unsigned long j = 0; j < x[i].size(); ++j) 165 { 166 if (x[i][j] < 0) 167 return false; 168 } 169 } 170 return true; 171 } 172 173 // ---------------------------------------------------------------------------------------- 174 // ---------------------------------------------------------------------------------------- 175 176 namespace impl 177 { 178 template < 179 typename T, 180 typename enable = void 181 > 182 struct fvect 183 { 184 // In this case type should be some sparse vector type 185 typedef typename T::type type; 186 }; 187 188 template < typename T > 189 struct fvect<T, typename enable_if<is_matrix<typename T::type> >::type> 190 { 191 // The point of this stuff is to create the proper matrix 192 // type to represent the concatenation of an edge vector 193 // with an node vector. 194 typedef typename T::type node_mat; 195 typedef typename T::edge_type edge_mat; 196 const static long NRd = node_mat::NR; 197 const static long NRe = edge_mat::NR; 198 const static long NR = ((NRd!=0) && (NRe!=0)) ? (NRd+NRe) : 0; 199 typedef typename node_mat::value_type value_type; 200 201 typedef matrix<value_type,NR,1, typename node_mat::mem_manager_type, typename node_mat::layout_type> type; 202 }; 203 } 204 205 // ---------------------------------------------------------------------------------------- 206 207 template < 208 typename graph_type 209 > 210 class structural_svm_graph_labeling_problem : noncopyable, 211 public structural_svm_problem_threaded<matrix<double,0,1>, 212 typename dlib::impl::fvect<graph_type>::type > 213 { 214 public: 215 typedef matrix<double,0,1> matrix_type; 216 typedef typename dlib::impl::fvect<graph_type>::type feature_vector_type; 217 218 typedef graph_type sample_type; 219 220 typedef std::vector<bool> label_type; 221 222 structural_svm_graph_labeling_problem( 223 const dlib::array<sample_type>& samples_, 224 const std::vector<label_type>& labels_, 225 const std::vector<std::vector<double> >& losses_, 226 unsigned long num_threads = 2 227 ) : 228 structural_svm_problem_threaded<matrix_type,feature_vector_type>(num_threads), 229 samples(samples_), 230 labels(labels_), 231 losses(losses_) 232 { 233 // make sure requires clause is not broken 234 #ifdef ENABLE_ASSERTS 235 std::string reason_for_failure; 236 DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true , 237 "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()" 238 << "\n\t Invalid inputs were given to this function." 239 << "\n\t reason_for_failure: " << reason_for_failure 240 << "\n\t samples.size(): " << samples.size() 241 << "\n\t labels.size(): " << labels.size() 242 << "\n\t this: " << this ); 243 DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) && 244 all_values_are_nonnegative(losses) == true, 245 "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()" 246 << "\n\t Invalid inputs were given to this function." 247 << "\n\t labels.size(): " << labels.size() 248 << "\n\t losses.size(): " << losses.size() 249 << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses) 250 << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses) 251 << "\n\t this: " << this ); 252 #endif 253 254 loss_pos = 1.0; 255 loss_neg = 1.0; 256 257 // figure out how many dimensions are in the node and edge vectors. 258 node_dims = 0; 259 edge_dims = 0; 260 for (unsigned long i = 0; i < samples.size(); ++i) 261 { 262 for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j) 263 { 264 node_dims = std::max(node_dims,(long)max_index_plus_one(samples[i].node(j).data)); 265 for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n) 266 { 267 edge_dims = std::max(edge_dims, (long)max_index_plus_one(samples[i].node(j).edge(n))); 268 } 269 } 270 } 271 } 272 273 const std::vector<std::vector<double> >& get_losses ( 274 ) const { return losses; } 275 276 long get_num_edge_weights ( 277 ) const 278 { 279 return edge_dims; 280 } 281 282 void set_loss_on_positive_class ( 283 double loss 284 ) 285 { 286 // make sure requires clause is not broken 287 DLIB_ASSERT(loss >= 0 && get_losses().size() == 0, 288 "\t void structural_svm_graph_labeling_problem::set_loss_on_positive_class()" 289 << "\n\t Invalid inputs were given to this function." 290 << "\n\t loss: " << loss 291 << "\n\t this: " << this ); 292 293 loss_pos = loss; 294 } 295 296 void set_loss_on_negative_class ( 297 double loss 298 ) 299 { 300 // make sure requires clause is not broken 301 DLIB_ASSERT(loss >= 0 && get_losses().size() == 0, 302 "\t void structural_svm_graph_labeling_problem::set_loss_on_negative_class()" 303 << "\n\t Invalid inputs were given to this function." 304 << "\n\t loss: " << loss 305 << "\n\t this: " << this ); 306 307 loss_neg = loss; 308 } 309 310 double get_loss_on_negative_class ( 311 ) const 312 { 313 // make sure requires clause is not broken 314 DLIB_ASSERT(get_losses().size() == 0, 315 "\t double structural_svm_graph_labeling_problem::get_loss_on_negative_class()" 316 << "\n\t Invalid inputs were given to this function." 317 << "\n\t this: " << this ); 318 319 return loss_neg; 320 } 321 322 double get_loss_on_positive_class ( 323 ) const 324 { 325 // make sure requires clause is not broken 326 DLIB_ASSERT(get_losses().size() == 0, 327 "\t double structural_svm_graph_labeling_problem::get_loss_on_positive_class()" 328 << "\n\t Invalid inputs were given to this function." 329 << "\n\t this: " << this ); 330 331 return loss_pos; 332 } 333 334 335 private: 336 virtual long get_num_dimensions ( 337 ) const 338 { 339 // The psi/w vector will begin with all the edge dims and then follow with the node dims. 340 return edge_dims + node_dims; 341 } 342 343 virtual long get_num_samples ( 344 ) const 345 { 346 return samples.size(); 347 } 348 349 template <typename psi_type> 350 typename enable_if<is_matrix<psi_type> >::type get_joint_feature_vector ( 351 const sample_type& sample, 352 const label_type& label, 353 psi_type& psi 354 ) const 355 { 356 psi.set_size(get_num_dimensions()); 357 psi = 0; 358 for (unsigned long i = 0; i < sample.number_of_nodes(); ++i) 359 { 360 // accumulate the node vectors 361 if (label[i] == true) 362 set_rowm(psi, range(edge_dims, psi.size()-1)) += sample.node(i).data; 363 364 for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n) 365 { 366 const unsigned long j = sample.node(i).neighbor(n).index(); 367 368 // Don't double count edges. Also only include the vector if 369 // the labels disagree. 370 if (i < j && label[i] != label[j]) 371 { 372 set_rowm(psi, range(0, edge_dims-1)) -= sample.node(i).edge(n); 373 } 374 } 375 } 376 } 377 378 template <typename T> 379 void add_to_sparse_vect ( 380 T& psi, 381 const T& vect, 382 unsigned long offset 383 ) const 384 { 385 for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i) 386 { 387 psi.insert(psi.end(), std::make_pair(i->first+offset, i->second)); 388 } 389 } 390 391 template <typename T> 392 void subtract_from_sparse_vect ( 393 T& psi, 394 const T& vect 395 ) const 396 { 397 for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i) 398 { 399 psi.insert(psi.end(), std::make_pair(i->first, -i->second)); 400 } 401 } 402 403 template <typename psi_type> 404 typename disable_if<is_matrix<psi_type> >::type get_joint_feature_vector ( 405 const sample_type& sample, 406 const label_type& label, 407 psi_type& psi 408 ) const 409 { 410 psi.clear(); 411 for (unsigned long i = 0; i < sample.number_of_nodes(); ++i) 412 { 413 // accumulate the node vectors 414 if (label[i] == true) 415 add_to_sparse_vect(psi, sample.node(i).data, edge_dims); 416 417 for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n) 418 { 419 const unsigned long j = sample.node(i).neighbor(n).index(); 420 421 // Don't double count edges. Also only include the vector if 422 // the labels disagree. 423 if (i < j && label[i] != label[j]) 424 { 425 subtract_from_sparse_vect(psi, sample.node(i).edge(n)); 426 } 427 } 428 } 429 } 430 431 virtual void get_truth_joint_feature_vector ( 432 long idx, 433 feature_vector_type& psi 434 ) const 435 { 436 get_joint_feature_vector(samples[idx], labels[idx], psi); 437 } 438 439 virtual void separation_oracle ( 440 const long idx, 441 const matrix_type& current_solution, 442 double& loss, 443 feature_vector_type& psi 444 ) const 445 { 446 const sample_type& samp = samples[idx]; 447 448 // setup the potts graph based on samples[idx] and current_solution. 449 graph<double,double>::kernel_1a g; 450 copy_graph_structure(samp, g); 451 for (unsigned long i = 0; i < g.number_of_nodes(); ++i) 452 { 453 g.node(i).data = dot(rowm(current_solution,range(edge_dims,current_solution.size()-1)), 454 samp.node(i).data); 455 456 // Include a loss augmentation so that we will get the proper loss augmented 457 // max when we use find_max_factor_graph_potts() below. 458 if (labels[idx][i]) 459 g.node(i).data -= get_loss_for_sample(idx,i,!labels[idx][i]); 460 else 461 g.node(i).data += get_loss_for_sample(idx,i,!labels[idx][i]); 462 463 for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) 464 { 465 const unsigned long j = g.node(i).neighbor(n).index(); 466 // Don't compute an edge weight more than once. 467 if (i < j) 468 { 469 g.node(i).edge(n) = dot(rowm(current_solution,range(0,edge_dims-1)), 470 samp.node(i).edge(n)); 471 } 472 } 473 474 } 475 476 std::vector<node_label> labeling; 477 find_max_factor_graph_potts(g, labeling); 478 479 480 std::vector<bool> bool_labeling; 481 bool_labeling.reserve(labeling.size()); 482 // figure out the loss 483 loss = 0; 484 for (unsigned long i = 0; i < labeling.size(); ++i) 485 { 486 const bool predicted_label = (labeling[i]!= 0); 487 bool_labeling.push_back(predicted_label); 488 loss += get_loss_for_sample(idx, i, predicted_label); 489 } 490 491 // compute psi 492 get_joint_feature_vector(samp, bool_labeling, psi); 493 } 494 495 double get_loss_for_sample ( 496 long sample_idx, 497 long node_idx, 498 bool predicted_label 499 ) const 500 /*! 501 requires 502 - 0 <= sample_idx < labels.size() 503 - 0 <= node_idx < labels[sample_idx].size() 504 ensures 505 - returns the loss incurred for predicting that the node 506 samples[sample_idx].node(node_idx) has a label of predicted_label. 507 !*/ 508 { 509 const bool true_label = labels[sample_idx][node_idx]; 510 if (true_label != predicted_label) 511 { 512 if (losses.size() != 0) 513 return losses[sample_idx][node_idx]; 514 else if (true_label == true) 515 return loss_pos; 516 else 517 return loss_neg; 518 } 519 else 520 { 521 // no loss for making the correct prediction. 522 return 0; 523 } 524 } 525 526 const dlib::array<sample_type>& samples; 527 const std::vector<label_type>& labels; 528 const std::vector<std::vector<double> >& losses; 529 530 long node_dims; 531 long edge_dims; 532 double loss_pos; 533 double loss_neg; 534 }; 535 536 // ---------------------------------------------------------------------------------------- 537 538 } 539 540 #endif // DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_ 541 542 543