// Copyright (C) 2012  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_
#define DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_


#include "structural_svm_graph_labeling_problem_abstract.h"
#include "../graph_cuts.h"
#include "../matrix.h"
#include "../array.h"
#include <vector>
#include <iterator>
#include "structural_svm_problem_threaded.h"
#include "../graph.h"
#include "sparse_vector.h"
#include <sstream>

// ----------------------------------------------------------------------------------------

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename graph_type
        >
    bool is_graph_labeling_problem (
        const dlib::array<graph_type>& samples,
        const std::vector<std::vector<bool> >& labels,
        std::string& reason_for_failure
    )
    {
        typedef typename graph_type::type node_vector_type;
        typedef typename graph_type::edge_type edge_vector_type;
        // The graph must use all dense vectors or all sparse vectors.  It can't mix the two types together.
        COMPILE_TIME_ASSERT( (is_matrix<node_vector_type>::value && is_matrix<edge_vector_type>::value) ||
                            (!is_matrix<node_vector_type>::value && !is_matrix<edge_vector_type>::value));


        std::ostringstream sout;
        reason_for_failure.clear();

        if (!is_learning_problem(samples, labels))
        {
            reason_for_failure = "is_learning_problem(samples, labels) returned false.";
            return false;
        }

        const bool ismat = is_matrix<typename graph_type::type>::value;

        // these are -1 until assigned with a value
        long node_dims = -1;
        long edge_dims = -1;

        for (unsigned long i = 0; i < samples.size(); ++i)
        {
            if (samples[i].number_of_nodes() != labels[i].size())
            {
                sout << "samples["<<i<<"].number_of_nodes() doesn't match labels["<<i<<"].size().";
                reason_for_failure = sout.str();
                return false;
            }
            if (graph_contains_length_one_cycle(samples[i]))
            {
                sout << "graph_contains_length_one_cycle(samples["<<i<<"]) returned true.";
                reason_for_failure = sout.str();
                return false;
            }

            for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j)
            {
                if (ismat && samples[i].node(j).data.size() == 0)
                {
                    sout << "A graph contains an empty vector at node: samples["<<i<<"].node("<<j<<").data.";
                    reason_for_failure = sout.str();
                    return false;
                }

                if (ismat && node_dims == -1)
                    node_dims = samples[i].node(j).data.size();
                // all nodes must have vectors of the same size.
                if (ismat && (long)samples[i].node(j).data.size() != node_dims)
                {
                    sout << "Not all node vectors in samples["<<i<<"] are the same dimension.";
                    reason_for_failure = sout.str();
                    return false;
                }

                for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n)
                {
                    if (ismat && samples[i].node(j).edge(n).size() == 0)
                    {
                        sout << "A graph contains an empty vector at edge: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
                        reason_for_failure = sout.str();
                        return false;
                    }
                    if (min(samples[i].node(j).edge(n)) < 0)
                    {
                        sout << "A graph contains negative values on an edge vector at: samples["<<i<<"].node("<<j<<").edge("<<n<<").";
                        reason_for_failure = sout.str();
                        return false;
                    }

                    if (ismat && edge_dims == -1)
                        edge_dims = samples[i].node(j).edge(n).size();
                    // all edges must have vectors of the same size.
                    if (ismat && (long)samples[i].node(j).edge(n).size() != edge_dims)
                    {
                        sout << "Not all edge vectors in samples["<<i<<"] are the same dimension.";
                        reason_for_failure = sout.str();
                        return false;
                    }
                }
            }
        }

        return true;
    }

    template <
        typename graph_type
        >
    bool is_graph_labeling_problem (
        const dlib::array<graph_type>& samples,
        const std::vector<std::vector<bool> >& labels
    )
    {
        std::string reason_for_failure;
        return is_graph_labeling_problem(samples, labels, reason_for_failure);
    }
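
    // A rough usage sketch (illustrative only; graph_type, the feature dimensions, and
    // the variable names below are hypothetical and not defined by this header).  It
    // shows how training data might be validated before being handed to the solver:
    //
    //     typedef dlib::graph<matrix<double,3,1>, matrix<double,2,1> >::kernel_1a_c graph_type;
    //     dlib::array<graph_type> samples;
    //     std::vector<std::vector<bool> > labels;
    //     // ... populate samples with graphs and labels with one bool per node ...
    //     std::string reason;
    //     if (!is_graph_labeling_problem(samples, labels, reason))
    //         std::cout << "invalid training data: " << reason << std::endl;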

// ----------------------------------------------------------------------------------------

    template <
        typename T,
        typename U
        >
    bool sizes_match (
        const std::vector<std::vector<T> >& lhs,
        const std::vector<std::vector<U> >& rhs
    )
    {
        if (lhs.size() != rhs.size())
            return false;

        for (unsigned long i = 0; i < lhs.size(); ++i)
        {
            if (lhs[i].size() != rhs[i].size())
                return false;
        }

        return true;
    }

// ----------------------------------------------------------------------------------------

    inline bool all_values_are_nonnegative (
        const std::vector<std::vector<double> >& x
    )
    {
        for (unsigned long i = 0; i < x.size(); ++i)
        {
            for (unsigned long j = 0; j < x[i].size(); ++j)
            {
                if (x[i][j] < 0)
                    return false;
            }
        }
        return true;
    }

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    namespace impl
    {
        template <
            typename T,
            typename enable = void
            >
        struct fvect
        {
            // In this case type should be some sparse vector type
            typedef typename T::type type;
        };

        template < typename T >
        struct fvect<T, typename enable_if<is_matrix<typename T::type> >::type>
        {
            // The point of this specialization is to create the proper matrix
            // type to represent the concatenation of an edge vector
            // with a node vector.
            typedef typename T::type      node_mat;
            typedef typename T::edge_type edge_mat;
            const static long NRd = node_mat::NR;
            const static long NRe = edge_mat::NR;
            const static long NR = ((NRd!=0) && (NRe!=0)) ? (NRd+NRe) : 0;
            typedef typename node_mat::value_type value_type;

            typedef matrix<value_type,NR,1, typename node_mat::mem_manager_type, typename node_mat::layout_type> type;
        };
    }

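    // For illustration (assumed, fixed-size feature dimensions): if graph_type stores
    // matrix<double,3,1> node vectors and matrix<double,2,1> edge vectors, then
    // impl::fvect<graph_type>::type is matrix<double,5,1>, large enough to hold the edge
    // feature dimensions followed by the node feature dimensions.  If either dimension
    // is dynamic, the resulting type is a dynamically sized column vector.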
// ----------------------------------------------------------------------------------------

    template <
        typename graph_type
        >
    class structural_svm_graph_labeling_problem : noncopyable,
        public structural_svm_problem_threaded<matrix<double,0,1>,
                                            typename dlib::impl::fvect<graph_type>::type >
    {
    public:
        typedef matrix<double,0,1> matrix_type;
        typedef typename dlib::impl::fvect<graph_type>::type feature_vector_type;

        typedef graph_type sample_type;

        typedef std::vector<bool> label_type;

        structural_svm_graph_labeling_problem(
            const dlib::array<sample_type>& samples_,
            const std::vector<label_type>& labels_,
            const std::vector<std::vector<double> >& losses_,
            unsigned long num_threads = 2
        ) :
            structural_svm_problem_threaded<matrix_type,feature_vector_type>(num_threads),
            samples(samples_),
            labels(labels_),
            losses(losses_)
        {
            // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
            std::string reason_for_failure;
            DLIB_ASSERT(is_graph_labeling_problem(samples, labels, reason_for_failure) == true,
                    "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t reason_for_failure: " << reason_for_failure
                    << "\n\t samples.size(): " << samples.size()
                    << "\n\t labels.size():  " << labels.size()
                    << "\n\t this: " << this );
            DLIB_ASSERT((losses.size() == 0 || sizes_match(labels, losses) == true) &&
                        all_values_are_nonnegative(losses) == true,
                    "\t structural_svm_graph_labeling_problem::structural_svm_graph_labeling_problem()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t labels.size():  " << labels.size()
                    << "\n\t losses.size():  " << losses.size()
                    << "\n\t sizes_match(labels,losses): " << sizes_match(labels,losses)
                    << "\n\t all_values_are_nonnegative(losses): " << all_values_are_nonnegative(losses)
                    << "\n\t this: " << this );
#endif

            loss_pos = 1.0;
            loss_neg = 1.0;

            // figure out how many dimensions are in the node and edge vectors.
            node_dims = 0;
            edge_dims = 0;
            for (unsigned long i = 0; i < samples.size(); ++i)
            {
                for (unsigned long j = 0; j < samples[i].number_of_nodes(); ++j)
                {
                    node_dims = std::max(node_dims,(long)max_index_plus_one(samples[i].node(j).data));
                    for (unsigned long n = 0; n < samples[i].node(j).number_of_neighbors(); ++n)
                    {
                        edge_dims = std::max(edge_dims, (long)max_index_plus_one(samples[i].node(j).edge(n)));
                    }
                }
            }
        }

        const std::vector<std::vector<double> >& get_losses (
        ) const { return losses; }

        long get_num_edge_weights (
        ) const
        {
            return edge_dims;
        }

        void set_loss_on_positive_class (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss >= 0 && get_losses().size() == 0,
                    "\t void structural_svm_graph_labeling_problem::set_loss_on_positive_class()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t loss: " << loss
                    << "\n\t this: " << this );

            loss_pos = loss;
        }

        void set_loss_on_negative_class (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss >= 0 && get_losses().size() == 0,
                    "\t void structural_svm_graph_labeling_problem::set_loss_on_negative_class()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t loss: " << loss
                    << "\n\t this: " << this );

            loss_neg = loss;
        }

        double get_loss_on_negative_class (
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(get_losses().size() == 0,
                    "\t double structural_svm_graph_labeling_problem::get_loss_on_negative_class()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t this: " << this );

            return loss_neg;
        }

        double get_loss_on_positive_class (
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(get_losses().size() == 0,
                    "\t double structural_svm_graph_labeling_problem::get_loss_on_positive_class()"
                    << "\n\t Invalid inputs were given to this function."
                    << "\n\t this: " << this );

            return loss_pos;
        }


    private:
        virtual long get_num_dimensions (
        ) const
        {
            // The psi/w vector will begin with all the edge dims and then follow with the node dims.
            return edge_dims + node_dims;
        }

        virtual long get_num_samples (
        ) const
        {
            return samples.size();
        }

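        // Both get_joint_feature_vector() overloads below (dense and sparse) build the
        // same PSI(sample,label): the first edge_dims entries accumulate the negated edge
        // vectors of edges whose endpoints receive different labels, and the remaining
        // node_dims entries accumulate the node vectors of nodes labeled true.  So
        // dot(w,PSI) rewards nodes labeled true according to the node weights and
        // penalizes label disagreement across edges according to the edge weights.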
        template <typename psi_type>
        typename enable_if<is_matrix<psi_type> >::type get_joint_feature_vector (
            const sample_type& sample,
            const label_type& label,
            psi_type& psi
        ) const
        {
            psi.set_size(get_num_dimensions());
            psi = 0;
            for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
            {
                // accumulate the node vectors
                if (label[i] == true)
                    set_rowm(psi, range(edge_dims, psi.size()-1)) += sample.node(i).data;

                for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
                {
                    const unsigned long j = sample.node(i).neighbor(n).index();

                    // Don't double count edges.  Also only include the vector if
                    // the labels disagree.
                    if (i < j && label[i] != label[j])
                    {
                        set_rowm(psi, range(0, edge_dims-1)) -= sample.node(i).edge(n);
                    }
                }
            }
        }

        template <typename T>
        void add_to_sparse_vect (
            T& psi,
            const T& vect,
            unsigned long offset
        ) const
        {
            for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i)
            {
                psi.insert(psi.end(), std::make_pair(i->first+offset, i->second));
            }
        }

        template <typename T>
        void subtract_from_sparse_vect (
            T& psi,
            const T& vect
        ) const
        {
            for (typename T::const_iterator i = vect.begin(); i != vect.end(); ++i)
            {
                psi.insert(psi.end(), std::make_pair(i->first, -i->second));
            }
        }

        template <typename psi_type>
        typename disable_if<is_matrix<psi_type> >::type get_joint_feature_vector (
            const sample_type& sample,
            const label_type& label,
            psi_type& psi
        ) const
        {
            psi.clear();
            for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
            {
                // accumulate the node vectors
                if (label[i] == true)
                    add_to_sparse_vect(psi, sample.node(i).data, edge_dims);

                for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
                {
                    const unsigned long j = sample.node(i).neighbor(n).index();

                    // Don't double count edges.  Also only include the vector if
                    // the labels disagree.
                    if (i < j && label[i] != label[j])
                    {
                        subtract_from_sparse_vect(psi, sample.node(i).edge(n));
                    }
                }
            }
        }

        virtual void get_truth_joint_feature_vector (
            long idx,
            feature_vector_type& psi
        ) const
        {
            get_joint_feature_vector(samples[idx], labels[idx], psi);
        }

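        // separation_oracle() solves the loss-augmented inference problem for sample idx:
        // find the labeling that maximizes the model score under current_solution plus the
        // per-node loss for disagreeing with the true label.  The loss terms are folded
        // into the unary Potts potentials below, so a single call to
        // find_max_factor_graph_potts() (a graph cut) returns the most violated labeling.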
        virtual void separation_oracle (
            const long idx,
            const matrix_type& current_solution,
            double& loss,
            feature_vector_type& psi
        ) const
        {
            const sample_type& samp = samples[idx];

            // setup the potts graph based on samples[idx] and current_solution.
            graph<double,double>::kernel_1a g;
            copy_graph_structure(samp, g);
            for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
            {
                g.node(i).data = dot(rowm(current_solution,range(edge_dims,current_solution.size()-1)),
                                    samp.node(i).data);

                // Include a loss augmentation so that we will get the proper loss augmented
                // max when we use find_max_factor_graph_potts() below.
                if (labels[idx][i])
                    g.node(i).data -= get_loss_for_sample(idx,i,!labels[idx][i]);
                else
                    g.node(i).data += get_loss_for_sample(idx,i,!labels[idx][i]);

                for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
                {
                    const unsigned long j = g.node(i).neighbor(n).index();
                    // Don't compute an edge weight more than once.
                    if (i < j)
                    {
                        g.node(i).edge(n) = dot(rowm(current_solution,range(0,edge_dims-1)),
                                                samp.node(i).edge(n));
                    }
                }

            }

            std::vector<node_label> labeling;
            find_max_factor_graph_potts(g, labeling);


            std::vector<bool> bool_labeling;
            bool_labeling.reserve(labeling.size());
            // figure out the loss
            loss = 0;
            for (unsigned long i = 0; i < labeling.size(); ++i)
            {
                const bool predicted_label = (labeling[i] != 0);
                bool_labeling.push_back(predicted_label);
                loss += get_loss_for_sample(idx, i, predicted_label);
            }

            // compute psi
            get_joint_feature_vector(samp, bool_labeling, psi);
        }

        double get_loss_for_sample (
            long sample_idx,
            long node_idx,
            bool predicted_label
        ) const
        /*!
            requires
                - 0 <= sample_idx < labels.size()
                - 0 <= node_idx < labels[sample_idx].size()
            ensures
                - returns the loss incurred for predicting that the node
                  samples[sample_idx].node(node_idx) has a label of predicted_label.
        !*/
        {
            const bool true_label = labels[sample_idx][node_idx];
            if (true_label != predicted_label)
            {
                if (losses.size() != 0)
                    return losses[sample_idx][node_idx];
                else if (true_label == true)
                    return loss_pos;
                else
                    return loss_neg;
            }
            else
            {
                // no loss for making the correct prediction.
                return 0;
            }
        }

        const dlib::array<sample_type>& samples;
        const std::vector<label_type>& labels;
        const std::vector<std::vector<double> >& losses;

        long node_dims;
        long edge_dims;
        double loss_pos;
        double loss_neg;
    };
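
    // A rough sketch of how this problem object is typically driven (hypothetical
    // variable names; structural_graph_labeling_trainer wraps this pattern behind a
    // complete interface):
    //
    //     structural_svm_graph_labeling_problem<graph_type> prob(samples, labels, losses, 4);
    //     prob.set_c(10);
    //     oca solver;
    //     matrix<double,0,1> weights;
    //     // The first get_num_edge_weights() elements of the weight vector score edges and
    //     // must remain non-negative for the graph cut inference to be valid, so that count
    //     // is passed to oca as the number of non-negative weights.
    //     solver(prob, weights, prob.get_num_edge_weights());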

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_STRUCTURAL_SVM_GRAPH_LAbELING_PROBLEM_Hh_