// Copyright (C) 2008  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_KCENTROId_
#define DLIB_KCENTROId_

#include <vector>
#include <cmath>
#include <limits>

#include "kcentroid_abstract.h"
#include "../matrix.h"
#include "function.h"
#include "../std_allocator.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <typename kernel_type>
    class kcentroid
    {
        /*!
            This object represents a weighted sum of sample points in a kernel induced
            feature space.  It can be used to kernelize any algorithm that requires only
            the ability to perform vector addition, subtraction, scalar multiplication,
            and inner products.  It uses the sparsification technique described in the
            paper The Kernel Recursive Least Squares Algorithm by Yaakov Engel.

            To understand the code it would also be useful to consult page 114 of the book
            Kernel Methods for Pattern Analysis by Shawe-Taylor and Cristianini as well as page 554
            (particularly equation 18.31) of the book Learning with Kernels by Scholkopf and
            Smola.  Everything you really need to know is in the Engel paper.  But the other
            books help give more perspective on the issues involved.


            INITIAL VALUE
                - min_strength == 0
                - min_vect_idx == 0
                - K_inv.size() == 0
                - K.size() == 0
                - dictionary.size() == 0
                - bias == 0
                - bias_is_stale == false

            CONVENTION
                - max_dictionary_size() == my_max_dictionary_size
                - get_kernel() == kernel

                - K.nr() == dictionary.size()
                - K.nc() == dictionary.size()
                - for all valid r,c:
                    - K(r,c) == kernel(dictionary[r], dictionary[c])
                - K_inv == inv(K)

                - if (dictionary.size() == my_max_dictionary_size && my_remove_oldest_first == false) then
                    - for all valid 0 <= i < dictionary.size():
                        - Let STRENGTHS[i] == the delta you would get for dictionary[i] (i.e. Approximately
                          Linearly Dependent value) if you removed dictionary[i] from this object and then
                          tried to add it back in.
                        - min_strength == the minimum value from STRENGTHS
                        - min_vect_idx == the index of the element in STRENGTHS with the smallest value

        !*/
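
        /*
            A rough usage sketch (illustrative only; see kcentroid_abstract.h and the
            kcentroid example program for the authoritative documentation).  It assumes
            dlib's radial_basis_kernel and a matrix<double,2,1> sample type:

                typedef matrix<double,2,1> sample_type;
                typedef radial_basis_kernel<sample_type> kernel_type;

                kcentroid<kernel_type> kc(kernel_type(0.1), 0.01, 20);

                sample_type s;
                s = 1, 2;
                kc.train(s);            // fold s into the running centroid
                double dist = kc(s);    // distance from s to the current centroid
        */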

    public:
        typedef typename kernel_type::scalar_type scalar_type;
        typedef typename kernel_type::sample_type sample_type;
        typedef typename kernel_type::mem_manager_type mem_manager_type;

        kcentroid (
        ) :
            my_remove_oldest_first(false),
            my_tolerance(0.001),
            my_max_dictionary_size(1000000),
            bias(0),
            bias_is_stale(false)
        {
            clear_dictionary();
        }

        explicit kcentroid (
            const kernel_type& kernel_,
            scalar_type tolerance_ = 0.001,
            unsigned long max_dictionary_size_ = 1000000,
            bool remove_oldest_first_ = false
        ) :
            my_remove_oldest_first(remove_oldest_first_),
            kernel(kernel_),
            my_tolerance(tolerance_),
            my_max_dictionary_size(max_dictionary_size_),
            bias(0),
            bias_is_stale(false)
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(tolerance_ > 0 && max_dictionary_size_ > 1,
                "\tkcentroid::kcentroid()"
                << "\n\t You have to give a positive tolerance and a max dictionary size greater than 1"
                << "\n\t this:                 " << this
                << "\n\t tolerance_:           " << tolerance_
                << "\n\t max_dictionary_size_: " << max_dictionary_size_
                );

            clear_dictionary();
        }

        scalar_type tolerance() const
        {
            return my_tolerance;
        }

        unsigned long max_dictionary_size() const
        {
            return my_max_dictionary_size;
        }

        bool remove_oldest_first (
        ) const
        {
            return my_remove_oldest_first;
        }

        const kernel_type& get_kernel (
        ) const
        {
            return kernel;
        }

        void clear_dictionary ()
        {
            dictionary.clear();
            alpha.clear();

            min_strength = 0;
            min_vect_idx = 0;
            K_inv.set_size(0,0);
            K.set_size(0,0);
            samples_seen = 0;
            bias = 0;
            bias_is_stale = false;
        }

        scalar_type operator() (
            const kcentroid& x
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(x.get_kernel() == get_kernel(),
                "\tscalar_type kcentroid::operator()(const kcentroid& x)"
                << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
                << "\n\tthis: " << this
                );

            // make sure the bias terms are up to date
            refresh_bias();
            x.refresh_bias();

            scalar_type temp = x.bias + bias - 2*inner_product(x);

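            // temp is the squared distance between the two centroids in feature space.
            // Floating point rounding can push it slightly below zero, so clamp it at 0
            // before taking the square root.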
            if (temp > 0)
                return std::sqrt(temp);
            else
                return 0;
        }

        scalar_type inner_product (
            const sample_type& x
        ) const
        {
            scalar_type temp = 0;
            for (unsigned long i = 0; i < alpha.size(); ++i)
                temp += alpha[i]*kernel(dictionary[i], x);
            return temp;
        }

        scalar_type inner_product (
            const kcentroid& x
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(x.get_kernel() == get_kernel(),
                "\tscalar_type kcentroid::inner_product(const kcentroid& x)"
                << "\n\tYou can only compare two kcentroid objects if they use the same kernel"
                << "\n\tthis: " << this
                );

            scalar_type temp = 0;
            for (unsigned long i = 0; i < alpha.size(); ++i)
            {
                for (unsigned long j = 0; j < x.alpha.size(); ++j)
                {
                    temp += alpha[i]*x.alpha[j]*kernel(dictionary[i], x.dictionary[j]);
                }
            }
            return temp;
        }

        scalar_type squared_norm (
        ) const
        {
            refresh_bias();
            return bias;
        }

        scalar_type operator() (
            const sample_type& x
        ) const
        {
            // make sure the bias terms are up to date
            refresh_bias();

            const scalar_type kxx = kernel(x,x);

            scalar_type temp = kxx + bias - 2*inner_product(x);
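            // as above, clamp tiny negative values caused by floating point rounding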
            if (temp > 0)
                return std::sqrt(temp);
            else
                return 0;
        }

        scalar_type samples_trained (
        ) const
        {
            return samples_seen;
        }

        scalar_type test_and_train (
            const sample_type& x
        )
        {
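            // Fold x into a running mean of everything seen so far: the current centroid
            // is scaled by (n-1)/n and x is added in with weight 1/n.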
            ++samples_seen;
            const scalar_type xscale = 1/samples_seen;
            const scalar_type cscale = 1-xscale;
            return train_and_maybe_test(x,cscale,xscale,true);
        }

        void train (
            const sample_type& x
        )
        {
            ++samples_seen;
            const scalar_type xscale = 1/samples_seen;
            const scalar_type cscale = 1-xscale;
            train_and_maybe_test(x,cscale,xscale,false);
        }

        scalar_type test_and_train (
            const sample_type& x,
            scalar_type cscale,
            scalar_type xscale
        )
        {
            ++samples_seen;
            return train_and_maybe_test(x,cscale,xscale,true);
        }

        void scale_by (
            scalar_type cscale
        )
        {
            for (unsigned long i = 0; i < alpha.size(); ++i)
            {
                alpha[i] = cscale*alpha[i];
            }
        }

        void train (
            const sample_type& x,
            scalar_type cscale,
            scalar_type xscale
        )
        {
            ++samples_seen;
            train_and_maybe_test(x,cscale,xscale,false);
        }

        void swap (
            kcentroid& item
        )
        {
            exchange(min_strength, item.min_strength);
            exchange(min_vect_idx, item.min_vect_idx);
            exchange(my_remove_oldest_first, item.my_remove_oldest_first);

            exchange(kernel, item.kernel);
            dictionary.swap(item.dictionary);
            alpha.swap(item.alpha);
            K_inv.swap(item.K_inv);
            K.swap(item.K);
            exchange(my_tolerance, item.my_tolerance);
            exchange(samples_seen, item.samples_seen);
            exchange(bias, item.bias);
            a.swap(item.a);
            k.swap(item.k);
            exchange(bias_is_stale, item.bias_is_stale);
            exchange(my_max_dictionary_size, item.my_max_dictionary_size);
        }

        unsigned long dictionary_size (
        ) const { return dictionary.size(); }

        friend void serialize(const kcentroid& item, std::ostream& out)
        {
            serialize(item.min_strength, out);
            serialize(item.min_vect_idx, out);
            serialize(item.my_remove_oldest_first, out);

            serialize(item.kernel, out);
            serialize(item.dictionary, out);
            serialize(item.alpha, out);
            serialize(item.K_inv, out);
            serialize(item.K, out);
            serialize(item.my_tolerance, out);
            serialize(item.samples_seen, out);
            serialize(item.bias, out);
            serialize(item.bias_is_stale, out);
            serialize(item.my_max_dictionary_size, out);
        }

        friend void deserialize(kcentroid& item, std::istream& in)
        {
            deserialize(item.min_strength, in);
            deserialize(item.min_vect_idx, in);
            deserialize(item.my_remove_oldest_first, in);

            deserialize(item.kernel, in);
            deserialize(item.dictionary, in);
            deserialize(item.alpha, in);
            deserialize(item.K_inv, in);
            deserialize(item.K, in);
            deserialize(item.my_tolerance, in);
            deserialize(item.samples_seen, in);
            deserialize(item.bias, in);
            deserialize(item.bias_is_stale, in);
            deserialize(item.my_max_dictionary_size, in);
        }

        distance_function<kernel_type> get_distance_function (
        ) const
        {
            refresh_bias();
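            // The returned object represents the same point, sum_i alpha[i]*phi(dictionary[i]),
            // with bias holding its squared norm in feature space.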
            return distance_function<kernel_type>(mat(alpha),
                                                  bias,
                                                  kernel,
                                                  mat(dictionary));
        }

    private:

        void refresh_bias (
        ) const
        {
            if (bias_is_stale)
            {
                bias_is_stale = false;
                // recompute the bias term
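                // bias == trans(mat(alpha))*K*mat(alpha), i.e. the squared norm of the
                // centroid in the kernel feature space.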
                bias = sum(pointwise_multiply(K, mat(alpha)*trans(mat(alpha))));
            }
        }

        scalar_type train_and_maybe_test (
            const sample_type& x,
            scalar_type cscale,
            scalar_type xscale,
            bool do_test
        )
        {
            scalar_type test_result = 0;
            const scalar_type kx = kernel(x,x);
            if (alpha.size() == 0)
            {
                // just ignore this sample if it is the zero vector (or really close to being zero)
                if (std::abs(kx) > std::numeric_limits<scalar_type>::epsilon())
                {
                    // set initial state since this is the first training example we have seen

                    K_inv.set_size(1,1);
                    K_inv(0,0) = 1/kx;
                    K.set_size(1,1);
                    K(0,0) = kx;

                    alpha.push_back(xscale);
                    dictionary.push_back(x);
                }
                else
                {
                    // the distance between an empty kcentroid and the zero vector is zero by definition.
                    return 0;
                }
            }
            else
            {
                // fill in k
                k.set_size(alpha.size());
                for (long r = 0; r < k.nr(); ++r)
                    k(r) = kernel(x,dictionary[r]);

                if (do_test)
                {
                    refresh_bias();
                    test_result = std::sqrt(kx + bias - 2*trans(mat(alpha))*k);
                }

                // compute the error we would have if we approximated the new x sample
                // with the dictionary.  That is, do the ALD test from the KRLS paper.
                a = K_inv*k;
                scalar_type delta = kx - trans(k)*a;
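                // delta == kernel(x,x) - trans(k)*inv(K)*k, the squared residual left over
                // when phi(x) is projected onto the span of the current dictionary vectors.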

                // if this new vector isn't approximately linearly dependent on the vectors
                // in our dictionary.
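                // (min_strength is 0 unless the dictionary is full and remove_oldest_first
                // is false, so normally only the tolerance test matters.  When the dictionary
                // is full, x must also be less linearly dependent than the weakest current
                // dictionary vector before it is allowed to displace it.)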
                if (delta > min_strength && delta > my_tolerance)
                {
                    bool need_to_update_min_strength = false;
                    if (dictionary.size() >= my_max_dictionary_size)
                    {
                        // We need to remove one of the old members of the dictionary before
                        // we proceed with adding a new one.
                        long idx_to_remove;
                        if (my_remove_oldest_first)
                        {
                            // remove the oldest one
                            idx_to_remove = 0;
                        }
                        else
                        {
                            // if we have never computed the min_strength then we should compute it
                            if (min_strength == 0)
                                recompute_min_strength();

                            // select the dictionary vector that is most linearly dependent for removal
                            idx_to_remove = min_vect_idx;
                            need_to_update_min_strength = true;
                        }

                        remove_dictionary_vector(idx_to_remove);

                        // recompute these guys since they were computed with the old
                        // kernel matrix
                        k = remove_row(k,idx_to_remove);
                        a = K_inv*k;
                        delta = kx - trans(k)*a;
                    }

                    // add x to the dictionary
                    dictionary.push_back(x);


                    // update K_inv by computing the new one in the temp matrix (equation 3.14)
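                    // The updated inverse has the block form:
                    //   [ K_inv + a*trans(a)/delta    -a/delta ]
                    //   [ -trans(a)/delta              1/delta ]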
                    matrix<scalar_type,0,0,mem_manager_type> temp(K_inv.nr()+1, K_inv.nc()+1);
                    // update the middle part of the matrix
                    set_subm(temp, get_rect(K_inv)) = K_inv + a*trans(a)/delta;
                    // update the right column of the matrix
                    set_subm(temp, 0, K_inv.nr(),K_inv.nr(),1) = -a/delta;
                    // update the bottom row of the matrix
                    set_subm(temp, K_inv.nr(), 0, 1, K_inv.nr()) = trans(-a/delta);
                    // update the bottom right corner of the matrix
                    temp(K_inv.nr(), K_inv.nc()) = 1/delta;
                    // put temp into K_inv
                    temp.swap(K_inv);



                    // update K (the kernel matrix)
                    temp.set_size(K.nr()+1, K.nc()+1);
                    set_subm(temp, get_rect(K)) = K;
                    // update the right column of the matrix
                    set_subm(temp, 0, K.nr(),K.nr(),1) = k;
                    // update the bottom row of the matrix
                    set_subm(temp, K.nr(), 0, 1, K.nr()) = trans(k);
                    temp(K.nr(), K.nc()) = kx;
                    // put temp into K
                    temp.swap(K);


                    // now update the alpha vector
                    for (unsigned long i = 0; i < alpha.size(); ++i)
                    {
                        alpha[i] *= cscale;
                    }
                    alpha.push_back(xscale);


                    if (need_to_update_min_strength)
                    {
                        // now we have to recompute the min_strength in this case
                        recompute_min_strength();
                    }
                }
                else
                {
                    // update the alpha vector so that this new sample has been added into
                    // the mean vector we are accumulating
                    for (unsigned long i = 0; i < alpha.size(); ++i)
                    {
                        alpha[i] = cscale*alpha[i] + xscale*a(i);
                    }
                }
            }

            bias_is_stale = true;

            return test_result;
        }

        void remove_dictionary_vector (
            long i
        )
        /*!
            requires
                - 0 <= i < dictionary.size()
            ensures
                - #dictionary.size() == dictionary.size() - 1
                - #alpha.size() == alpha.size() - 1
                - updates the K_inv matrix so that it is still a proper inverse of the
                  kernel matrix
                - also removes the necessary row and column from the K matrix
                - uses the this->a variable so after this function runs that variable
                  will contain a different value.
        !*/
        {
            // remove the dictionary vector
            dictionary.erase(dictionary.begin()+i);

            // remove the i'th vector from the inverse kernel matrix.  This formula is basically
            // just the reverse of the way K_inv is updated by equation 3.14 during normal training.
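            // That is, drop row and column i from K_inv and subtract off the rank-one term
            // formed from that row and column, scaled by 1/K_inv(i,i).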
            K_inv = removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i);

            // now compute the updated alpha values to take account that we just removed one of
            // our dictionary vectors
            a = (K_inv*remove_row(K,i)*mat(alpha));

            // now copy over the new alpha values
            alpha.resize(alpha.size()-1);
            for (unsigned long k = 0; k < alpha.size(); ++k)
            {
                alpha[k] = a(k);
            }

            // update the K matrix as well
            K = removerc(K,i,i);
        }

        void recompute_min_strength (
        )
        /*!
            ensures
                - recomputes the min_strength and min_vect_idx values
                  so that they are correct with respect to the CONVENTION
                - uses the this->a variable so after this function runs that variable
                  will contain a different value.
        !*/
        {
            min_strength = std::numeric_limits<scalar_type>::max();

            // here we loop over each dictionary vector and compute what its delta would be if
            // we were to remove it from the dictionary and then try to add it back in.
            for (unsigned long i = 0; i < dictionary.size(); ++i)
            {
                // compute a = K_inv*k but where dictionary vector i has been removed
                a = (removerc(K_inv,i,i) - remove_row(colm(K_inv,i)/K_inv(i,i),i)*remove_col(rowm(K_inv,i),i)) *
                    (remove_row(colm(K,i),i));
                scalar_type delta = K(i,i) - trans(remove_row(colm(K,i),i))*a;

                if (delta < min_strength)
                {
                    min_strength = delta;
                    min_vect_idx = i;
                }
            }
        }



        typedef std_allocator<sample_type, mem_manager_type> alloc_sample_type;
        typedef std_allocator<scalar_type, mem_manager_type> alloc_scalar_type;
        typedef std::vector<sample_type,alloc_sample_type> dictionary_vector_type;
        typedef std::vector<scalar_type,alloc_scalar_type> alpha_vector_type;


        scalar_type min_strength;
        unsigned long min_vect_idx;
        bool my_remove_oldest_first;

        kernel_type kernel;
        dictionary_vector_type dictionary;
        alpha_vector_type alpha;

        matrix<scalar_type,0,0,mem_manager_type> K_inv;
        matrix<scalar_type,0,0,mem_manager_type> K;

        scalar_type my_tolerance;
        unsigned long my_max_dictionary_size;
        scalar_type samples_seen;
        mutable scalar_type bias;
        mutable bool bias_is_stale;


        // temp variables here just so we don't have to reconstruct them over and over.  Thus,
        // they aren't really part of the state of this object.
        matrix<scalar_type,0,1,mem_manager_type> a;
        matrix<scalar_type,0,1,mem_manager_type> k;

    };

// ----------------------------------------------------------------------------------------

    template <typename kernel_type>
    void swap(kcentroid<kernel_type>& a, kcentroid<kernel_type>& b)
    { a.swap(b); }

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_KCENTROId_
