1 // Copyright (C) 2007  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_SVm_KERNEL
4 #define DLIB_SVm_KERNEL
5 
#include "kernel_abstract.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <sstream>
#include "../matrix.h"
#include "../algs.h"
#include "../serialize.h"
13 
14 namespace dlib
15 {
16 
17 // ----------------------------------------------------------------------------------------
18 
    // Primary template, intentionally declared but not defined here.  Each kernel
    // in this file that supports differentiation supplies its own partial
    // specialization.  The formulas in those specializations are the gradient of
    // the kernel with respect to its second argument (y).
    template < typename kernel_type > struct kernel_derivative;
20 
21 // ----------------------------------------------------------------------------------------
22 
23     template <
24         typename T
25         >
26     struct radial_basis_kernel
27     {
28         typedef typename T::type scalar_type;
29         typedef T sample_type;
30         typedef typename T::mem_manager_type mem_manager_type;
31 
32         // T must be capable of representing a column vector.
33         COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
34 
radial_basis_kernelradial_basis_kernel35         radial_basis_kernel(const scalar_type g) : gamma(g) {}
radial_basis_kernelradial_basis_kernel36         radial_basis_kernel() : gamma(0.1) {}
radial_basis_kernelradial_basis_kernel37         radial_basis_kernel(
38             const radial_basis_kernel& k
39         ) : gamma(k.gamma) {}
40 
41 
42         const scalar_type gamma;
43 
operatorradial_basis_kernel44         scalar_type operator() (
45             const sample_type& a,
46             const sample_type& b
47         ) const
48         {
49             const scalar_type d = trans(a-b)*(a-b);
50             return std::exp(-gamma*d);
51         }
52 
53         radial_basis_kernel& operator= (
54             const radial_basis_kernel& k
55         )
56         {
57             const_cast<scalar_type&>(gamma) = k.gamma;
58             return *this;
59         }
60 
61         bool operator== (
62             const radial_basis_kernel& k
63         ) const
64         {
65             return gamma == k.gamma;
66         }
67     };
68 
69     template <
70         typename T
71         >
serialize(const radial_basis_kernel<T> & item,std::ostream & out)72     void serialize (
73         const radial_basis_kernel<T>& item,
74         std::ostream& out
75     )
76     {
77         try
78         {
79             serialize(item.gamma, out);
80         }
81         catch (serialization_error& e)
82         {
83             throw serialization_error(e.info + "\n   while serializing object of type radial_basis_kernel");
84         }
85     }
86 
87     template <
88         typename T
89         >
deserialize(radial_basis_kernel<T> & item,std::istream & in)90     void deserialize (
91         radial_basis_kernel<T>& item,
92         std::istream& in
93     )
94     {
95         typedef typename T::type scalar_type;
96         try
97         {
98             deserialize(const_cast<scalar_type&>(item.gamma), in);
99         }
100         catch (serialization_error& e)
101         {
102             throw serialization_error(e.info + "\n   while deserializing object of type radial_basis_kernel");
103         }
104     }
105 
106     template <
107         typename T
108         >
109     struct kernel_derivative<radial_basis_kernel<T> >
110     {
111         typedef typename T::type scalar_type;
112         typedef T sample_type;
113         typedef typename T::mem_manager_type mem_manager_type;
114 
115         kernel_derivative(const radial_basis_kernel<T>& k_) : k(k_){}
116 
117         const sample_type& operator() (const sample_type& x, const sample_type& y) const
118         {
119             // return the derivative of the rbf kernel
120             temp = 2*k.gamma*(x-y)*k(x,y);
121             return temp;
122         }
123 
124         const radial_basis_kernel<T>& k;
125         mutable sample_type temp;
126     };
127 
128 // ----------------------------------------------------------------------------------------
129 
130     template <
131         typename T
132         >
133     struct polynomial_kernel
134     {
135         typedef typename T::type scalar_type;
136         typedef T sample_type;
137         typedef typename T::mem_manager_type mem_manager_type;
138 
139         // T must be capable of representing a column vector.
140         COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
141 
142         polynomial_kernel(const scalar_type g, const scalar_type c, const scalar_type d) : gamma(g), coef(c), degree(d) {}
143         polynomial_kernel() : gamma(1), coef(0), degree(1) {}
144         polynomial_kernel(
145             const polynomial_kernel& k
146         ) : gamma(k.gamma), coef(k.coef), degree(k.degree) {}
147 
148         typedef T type;
149         const scalar_type gamma;
150         const scalar_type coef;
151         const scalar_type degree;
152 
153         scalar_type operator() (
154             const sample_type& a,
155             const sample_type& b
156         ) const
157         {
158             return std::pow(gamma*(trans(a)*b) + coef, degree);
159         }
160 
161         polynomial_kernel& operator= (
162             const polynomial_kernel& k
163         )
164         {
165             const_cast<scalar_type&>(gamma) = k.gamma;
166             const_cast<scalar_type&>(coef) = k.coef;
167             const_cast<scalar_type&>(degree) = k.degree;
168             return *this;
169         }
170 
171         bool operator== (
172             const polynomial_kernel& k
173         ) const
174         {
175             return (gamma == k.gamma) && (coef == k.coef) && (degree == k.degree);
176         }
177     };
178 
179     template <
180         typename T
181         >
182     void serialize (
183         const polynomial_kernel<T>& item,
184         std::ostream& out
185     )
186     {
187         try
188         {
189             serialize(item.gamma, out);
190             serialize(item.coef, out);
191             serialize(item.degree, out);
192         }
193         catch (serialization_error& e)
194         {
195             throw serialization_error(e.info + "\n   while serializing object of type polynomial_kernel");
196         }
197     }
198 
199     template <
200         typename T
201         >
202     void deserialize (
203         polynomial_kernel<T>& item,
204         std::istream& in
205     )
206     {
207         typedef typename T::type scalar_type;
208         try
209         {
210             deserialize(const_cast<scalar_type&>(item.gamma), in);
211             deserialize(const_cast<scalar_type&>(item.coef), in);
212             deserialize(const_cast<scalar_type&>(item.degree), in);
213         }
214         catch (serialization_error& e)
215         {
216             throw serialization_error(e.info + "\n   while deserializing object of type polynomial_kernel");
217         }
218     }
219 
220     template <
221         typename T
222         >
223     struct kernel_derivative<polynomial_kernel<T> >
224     {
225         typedef typename T::type scalar_type;
226         typedef T sample_type;
227         typedef typename T::mem_manager_type mem_manager_type;
228 
229         kernel_derivative(const polynomial_kernel<T>& k_) : k(k_){}
230 
231         const sample_type& operator() (const sample_type& x, const sample_type& y) const
232         {
233             // return the derivative of the rbf kernel
234             temp = k.degree*k.gamma*x*std::pow(k.gamma*(trans(x)*y) + k.coef, k.degree-1);
235             return temp;
236         }
237 
238         const polynomial_kernel<T>& k;
239         mutable sample_type temp;
240     };
241 
242 // ----------------------------------------------------------------------------------------
243 
244     template <
245         typename T
246         >
247     struct sigmoid_kernel
248     {
249         typedef typename T::type scalar_type;
250         typedef T sample_type;
251         typedef typename T::mem_manager_type mem_manager_type;
252 
253         // T must be capable of representing a column vector.
254         COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
255 
256         sigmoid_kernel(const scalar_type g, const scalar_type c) : gamma(g), coef(c) {}
257         sigmoid_kernel() : gamma(0.1), coef(-1.0) {}
258         sigmoid_kernel(
259             const sigmoid_kernel& k
260         ) : gamma(k.gamma), coef(k.coef) {}
261 
262         typedef T type;
263         const scalar_type gamma;
264         const scalar_type coef;
265 
266         scalar_type operator() (
267             const sample_type& a,
268             const sample_type& b
269         ) const
270         {
271             return std::tanh(gamma*(trans(a)*b) + coef);
272         }
273 
274         sigmoid_kernel& operator= (
275             const sigmoid_kernel& k
276         )
277         {
278             const_cast<scalar_type&>(gamma) = k.gamma;
279             const_cast<scalar_type&>(coef) = k.coef;
280             return *this;
281         }
282 
283         bool operator== (
284             const sigmoid_kernel& k
285         ) const
286         {
287             return (gamma == k.gamma) && (coef == k.coef);
288         }
289     };
290 
291     template <
292         typename T
293         >
294     void serialize (
295         const sigmoid_kernel<T>& item,
296         std::ostream& out
297     )
298     {
299         try
300         {
301             serialize(item.gamma, out);
302             serialize(item.coef, out);
303         }
304         catch (serialization_error& e)
305         {
306             throw serialization_error(e.info + "\n   while serializing object of type sigmoid_kernel");
307         }
308     }
309 
310     template <
311         typename T
312         >
313     void deserialize (
314         sigmoid_kernel<T>& item,
315         std::istream& in
316     )
317     {
318         typedef typename T::type scalar_type;
319         try
320         {
321             deserialize(const_cast<scalar_type&>(item.gamma), in);
322             deserialize(const_cast<scalar_type&>(item.coef), in);
323         }
324         catch (serialization_error& e)
325         {
326             throw serialization_error(e.info + "\n   while deserializing object of type sigmoid_kernel");
327         }
328     }
329 
330     template <
331         typename T
332         >
333     struct kernel_derivative<sigmoid_kernel<T> >
334     {
335         typedef typename T::type scalar_type;
336         typedef T sample_type;
337         typedef typename T::mem_manager_type mem_manager_type;
338 
339         kernel_derivative(const sigmoid_kernel<T>& k_) : k(k_){}
340 
341         const sample_type& operator() (const sample_type& x, const sample_type& y) const
342         {
343             // return the derivative of the rbf kernel
344             temp = k.gamma*x*(1-std::pow(k(x,y),2));
345             return temp;
346         }
347 
348         const sigmoid_kernel<T>& k;
349         mutable sample_type temp;
350     };
351 
352 // ----------------------------------------------------------------------------------------
353 
354     template <typename T>
355     struct linear_kernel
356     {
357         typedef typename T::type scalar_type;
358         typedef T sample_type;
359         typedef typename T::mem_manager_type mem_manager_type;
360 
361         // T must be capable of representing a column vector.
362         COMPILE_TIME_ASSERT(T::NC == 1 || T::NC == 0);
363 
364         scalar_type operator() (
365             const sample_type& a,
366             const sample_type& b
367         ) const
368         {
369             return trans(a)*b;
370         }
371 
372         bool operator== (
373             const linear_kernel&
374         ) const
375         {
376             return true;
377         }
378     };
379 
380     template <
381         typename T
382         >
383     void serialize (
384         const linear_kernel<T>& ,
385         std::ostream&
386     ){}
387 
388     template <
389         typename T
390         >
391     void deserialize (
392         linear_kernel<T>& ,
393         std::istream&
394     ){}
395 
396     template <
397         typename T
398         >
399     struct kernel_derivative<linear_kernel<T> >
400     {
401         typedef typename T::type scalar_type;
402         typedef T sample_type;
403         typedef typename T::mem_manager_type mem_manager_type;
404 
405         kernel_derivative(const linear_kernel<T>& k_) : k(k_){}
406 
407         const sample_type& operator() (const sample_type& x, const sample_type& ) const
408         {
409             return x;
410         }
411 
412         const linear_kernel<T>& k;
413     };
414 
415 // ----------------------------------------------------------------------------------------
416 
417     template <typename T>
418     struct histogram_intersection_kernel
419     {
420         typedef typename T::type scalar_type;
421         typedef T sample_type;
422         typedef typename T::mem_manager_type mem_manager_type;
423 
424         scalar_type operator() (
425             const sample_type& a,
426             const sample_type& b
427         ) const
428         {
429             scalar_type temp = 0;
430             for (long i = 0; i < a.size(); ++i)
431             {
432                 temp += std::min(a(i), b(i));
433             }
434             return temp;
435         }
436 
437         bool operator== (
438             const histogram_intersection_kernel&
439         ) const
440         {
441             return true;
442         }
443     };
444 
445     template <
446         typename T
447         >
448     void serialize (
449         const histogram_intersection_kernel<T>& ,
450         std::ostream&
451     ){}
452 
453     template <
454         typename T
455         >
456     void deserialize (
457         histogram_intersection_kernel<T>& ,
458         std::istream&
459     ){}
460 
461 // ----------------------------------------------------------------------------------------
462 
463     template <typename T>
464     struct offset_kernel
465     {
466         typedef typename T::scalar_type scalar_type;
467         typedef typename T::sample_type sample_type;
468         typedef typename T::mem_manager_type mem_manager_type;
469 
470         offset_kernel(const T& k, const scalar_type& offset_
471         ) : kernel(k), offset(offset_) {}
472         offset_kernel() : kernel(T()), offset(0.01) {}
473         offset_kernel(
474             const offset_kernel& k
475         ) : kernel(k.kernel), offset(k.offset) {}
476 
477         const T kernel;
478         const scalar_type offset;
479 
480         scalar_type operator() (
481             const sample_type& a,
482             const sample_type& b
483         ) const
484         {
485             return kernel(a,b) + offset;
486         }
487 
488         offset_kernel& operator= (
489             const offset_kernel& k
490         )
491         {
492             const_cast<T&>(kernel) = k.kernel;
493             const_cast<scalar_type&>(offset) = k.offset;
494             return *this;
495         }
496 
497         bool operator== (
498             const offset_kernel& k
499         ) const
500         {
501             return k.kernel == kernel && offset == k.offset;
502         }
503     };
504 
505     template <
506         typename T
507         >
508     void serialize (
509         const offset_kernel<T>& item,
510         std::ostream& out
511     )
512     {
513         try
514         {
515             serialize(item.offset, out);
516             serialize(item.kernel, out);
517         }
518         catch (serialization_error& e)
519         {
520             throw serialization_error(e.info + "\n   while serializing object of type offset_kernel");
521         }
522     }
523 
524     template <
525         typename T
526         >
527     void deserialize (
528         offset_kernel<T>& item,
529         std::istream& in
530     )
531     {
532         typedef typename offset_kernel<T>::scalar_type scalar_type;
533         try
534         {
535             deserialize(const_cast<scalar_type&>(item.offset), in);
536             deserialize(const_cast<T&>(item.kernel), in);
537         }
538         catch (serialization_error& e)
539         {
540             throw serialization_error(e.info + "\n   while deserializing object of type offset_kernel");
541         }
542     }
543 
544     template <
545         typename T
546         >
547     struct kernel_derivative<offset_kernel<T> >
548     {
549         typedef typename T::scalar_type scalar_type;
550         typedef typename T::sample_type sample_type;
551         typedef typename T::mem_manager_type mem_manager_type;
552 
553         kernel_derivative(const offset_kernel<T>& k) : der(k.kernel){}
554 
555         const sample_type operator() (const sample_type& x, const sample_type& y) const
556         {
557             return der(x,y);
558         }
559 
560         kernel_derivative<T> der;
561     };
562 
563 // ----------------------------------------------------------------------------------------
564 
565 }
566 
567 #endif // DLIB_SVm_KERNEL
568 
569 
570