1 // Copyright (C) 2007  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_MLp_KERNEL_1_
4 #define DLIB_MLp_KERNEL_1_
5 
6 #include "../algs.h"
7 #include "../serialize.h"
8 #include "../matrix.h"
9 #include "../rand.h"
10 #include "mlp_kernel_abstract.h"
11 #include <ctime>
12 #include <sstream>
13 
14 namespace dlib
15 {
16 
17     class mlp_kernel_1 : noncopyable
18     {
19         /*!
20             INITIAL VALUE
21                 The network is initially initialized with random weights
22 
23             CONVENTION
24                 - input_layer_nodes() == input_nodes
25                 - first_hidden_layer_nodes() == first_hidden_nodes
26                 - second_hidden_layer_nodes() == second_hidden_nodes
27                 - output_layer_nodes() == output_nodes
28                 - get_alpha == alpha
29                 - get_momentum() == momentum
30 
31 
32                 - if (second_hidden_nodes == 0) then
33                     - for all i and j:
34                         - w1(i,j) == the weight on the link from node i in the first hidden layer
35                           to input node j
36                         - w3(i,j) == the weight on the link from node i in the output layer
37                           to first hidden layer node j
38                     - for all i and j:
39                         - w1m == the momentum terms for w1 from the previous update
40                         - w3m == the momentum terms for w3 from the previous update
41                 - else
42                     - for all i and j:
43                         - w1(i,j) == the weight on the link from node i in the first hidden layer
44                           to input node j
45                         - w2(i,j) == the weight on the link from node i in the second hidden layer
46                           to first hidden layer node j
47                         - w3(i,j) == the weight on the link from node i in the output layer
48                           to second hidden layer node j
49                     - for all i and j:
50                         - w1m == the momentum terms for w1 from the previous update
51                         - w2m == the momentum terms for w2 from the previous update
52                         - w3m == the momentum terms for w3 from the previous update
53         !*/
54 
55     public:
56 
57         mlp_kernel_1 (
58             long nodes_in_input_layer,
59             long nodes_in_first_hidden_layer,
60             long nodes_in_second_hidden_layer = 0,
61             long nodes_in_output_layer = 1,
62             double alpha_ = 0.1,
63             double momentum_ = 0.8
64         ) :
input_nodes(nodes_in_input_layer)65             input_nodes(nodes_in_input_layer),
66             first_hidden_nodes(nodes_in_first_hidden_layer),
67             second_hidden_nodes(nodes_in_second_hidden_layer),
68             output_nodes(nodes_in_output_layer),
69             alpha(alpha_),
70             momentum(momentum_)
71         {
72 
73             // seed the random number generator
74             std::ostringstream sout;
75             sout << time(0);
76             rand_nums.set_seed(sout.str());
77 
78             w1.set_size(first_hidden_nodes+1, input_nodes+1);
79             w1m.set_size(first_hidden_nodes+1, input_nodes+1);
80             z.set_size(input_nodes+1,1);
81 
82             if (second_hidden_nodes != 0)
83             {
84                 w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
85                 w3.set_size(output_nodes, second_hidden_nodes+1);
86 
87                 w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1);
88                 w3m.set_size(output_nodes, second_hidden_nodes+1);
89             }
90             else
91             {
92                 w3.set_size(output_nodes, first_hidden_nodes+1);
93 
94                 w3m.set_size(output_nodes, first_hidden_nodes+1);
95             }
96 
97             reset();
98         }
99 
~mlp_kernel_1()100         virtual ~mlp_kernel_1 (
101         ) {}
102 
reset()103         void reset (
104         )
105         {
106             // randomize the weights for the first layer
107             for (long r = 0; r < w1.nr(); ++r)
108                 for (long c = 0; c < w1.nc(); ++c)
109                     w1(r,c) = rand_nums.get_random_double();
110 
111             // randomize the weights for the second layer
112             for (long r = 0; r < w2.nr(); ++r)
113                 for (long c = 0; c < w2.nc(); ++c)
114                     w2(r,c) = rand_nums.get_random_double();
115 
116             // randomize the weights for the third layer
117             for (long r = 0; r < w3.nr(); ++r)
118                 for (long c = 0; c < w3.nc(); ++c)
119                     w3(r,c) = rand_nums.get_random_double();
120 
121             // zero all the momentum terms
122             set_all_elements(w1m,0);
123             set_all_elements(w2m,0);
124             set_all_elements(w3m,0);
125         }
126 
input_layer_nodes()127         long input_layer_nodes (
128         ) const { return input_nodes; }
129 
first_hidden_layer_nodes()130         long first_hidden_layer_nodes (
131         ) const { return first_hidden_nodes; }
132 
second_hidden_layer_nodes()133         long second_hidden_layer_nodes (
134         ) const { return second_hidden_nodes; }
135 
output_layer_nodes()136         long output_layer_nodes (
137         ) const { return output_nodes; }
138 
get_alpha()139         double get_alpha (
140         ) const { return alpha; }
141 
get_momentum()142         double get_momentum (
143         ) const { return momentum; }
144 
145         template <typename EXP>
operator()146         const matrix<double> operator() (
147             const matrix_exp<EXP>& in
148         ) const
149         {
150             for (long i = 0; i < in.nr(); ++i)
151                 z(i) = in(i);
152             // insert the bias
153             z(z.nr()-1) = -1;
154 
155             tmp1 = sigmoid(w1*z);
156             // insert the bias
157             tmp1(tmp1.nr()-1) = -1;
158 
159             if (second_hidden_nodes == 0)
160             {
161                 return sigmoid(w3*tmp1);
162             }
163             else
164             {
165                 tmp2 = sigmoid(w2*tmp1);
166                 // insert the bias
167                 tmp2(tmp2.nr()-1) = -1;
168 
169                 return sigmoid(w3*tmp2);
170             }
171         }
172 
173         template <typename EXP1, typename EXP2>
train(const matrix_exp<EXP1> & example_in,const matrix_exp<EXP2> & example_out)174         void train (
175             const matrix_exp<EXP1>& example_in,
176             const matrix_exp<EXP2>& example_out
177         )
178         {
179             for (long i = 0; i < example_in.nr(); ++i)
180                 z(i) = example_in(i);
181             // insert the bias
182             z(z.nr()-1) = -1;
183 
184             tmp1 = sigmoid(w1*z);
185             // insert the bias
186             tmp1(tmp1.nr()-1) = -1;
187 
188 
189             if (second_hidden_nodes == 0)
190             {
191                 o = sigmoid(w3*tmp1);
192 
193                 // now compute the errors and propagate them backwards though the network
194                 e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
195                 e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 );
196 
197                 // compute the new weight updates
198                 w3m = alpha * e3*trans(tmp1) + w3m*momentum;
199                 w1m = alpha * e1*trans(z)    + w1m*momentum;
200 
201                 // now update the weights
202                 w1 += w1m;
203                 w3 += w3m;
204             }
205             else
206             {
207                 tmp2 = sigmoid(w2*tmp1);
208                 // insert the bias
209                 tmp2(tmp2.nr()-1) = -1;
210 
211                 o = sigmoid(w3*tmp2);
212 
213 
214                 // now compute the errors and propagate them backwards though the network
215                 e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o);
216                 e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 );
217                 e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 );
218 
219                 // compute the new weight updates
220                 w3m = alpha * e3*trans(tmp2) + w3m*momentum;
221                 w2m = alpha * e2*trans(tmp1) + w2m*momentum;
222                 w1m = alpha * e1*trans(z)    + w1m*momentum;
223 
224                 // now update the weights
225                 w1 += w1m;
226                 w2 += w2m;
227                 w3 += w3m;
228             }
229         }
230 
231         template <typename EXP>
train(const matrix_exp<EXP> & example_in,double example_out)232         void train (
233             const matrix_exp<EXP>& example_in,
234             double example_out
235         )
236         {
237             matrix<double,1,1> e_out;
238             e_out(0) = example_out;
239             train(example_in,e_out);
240         }
241 
get_average_change()242         double get_average_change (
243         ) const
244         {
245             // sum up all the weight changes
246             double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m));
247 
248             // divide by the number of weights
249             delta /=  w1m.nr()*w1m.nc() +
250                 w2m.nr()*w2m.nc() +
251                 w3m.nr()*w3m.nc();
252 
253             return delta;
254         }
255 
swap(mlp_kernel_1 & item)256         void swap (
257             mlp_kernel_1& item
258         )
259         {
260             exchange(input_nodes, item.input_nodes);
261             exchange(first_hidden_nodes, item.first_hidden_nodes);
262             exchange(second_hidden_nodes, item.second_hidden_nodes);
263             exchange(output_nodes, item.output_nodes);
264             exchange(alpha, item.alpha);
265             exchange(momentum, item.momentum);
266 
267             w1.swap(item.w1);
268             w2.swap(item.w2);
269             w3.swap(item.w3);
270 
271             w1m.swap(item.w1m);
272             w2m.swap(item.w2m);
273             w3m.swap(item.w3m);
274 
275             // even swap the temporary matrices because this may ultimately result in
276             // fewer calls to new and delete.
277             e1.swap(item.e1);
278             e2.swap(item.e2);
279             e3.swap(item.e3);
280             z.swap(item.z);
281             tmp1.swap(item.tmp1);
282             tmp2.swap(item.tmp2);
283             o.swap(item.o);
284         }
285 
286 
287         friend void serialize (
288             const mlp_kernel_1& item,
289             std::ostream& out
290         );
291 
292         friend void deserialize (
293             mlp_kernel_1& item,
294             std::istream& in
295         );
296 
297     private:
298 
299         long input_nodes;
300         long first_hidden_nodes;
301         long second_hidden_nodes;
302         long output_nodes;
303         double alpha;
304         double momentum;
305 
306         matrix<double> w1;
307         matrix<double> w2;
308         matrix<double> w3;
309 
310         matrix<double> w1m;
311         matrix<double> w2m;
312         matrix<double> w3m;
313 
314 
315         rand rand_nums;
316 
317         // temporary storage
318         mutable matrix<double> e1, e2, e3;
319         mutable matrix<double> z, tmp1, tmp2, o;
320     };
321 
swap(mlp_kernel_1 & a,mlp_kernel_1 & b)322     inline void swap (
323         mlp_kernel_1& a,
324         mlp_kernel_1& b
325     ) { a.swap(b); }
326 
327 // ----------------------------------------------------------------------------------------
328 
serialize(const mlp_kernel_1 & item,std::ostream & out)329     inline void serialize (
330         const mlp_kernel_1& item,
331         std::ostream& out
332     )
333     {
334         try
335         {
336             serialize(item.input_nodes, out);
337             serialize(item.first_hidden_nodes, out);
338             serialize(item.second_hidden_nodes, out);
339             serialize(item.output_nodes, out);
340             serialize(item.alpha, out);
341             serialize(item.momentum, out);
342 
343             serialize(item.w1, out);
344             serialize(item.w2, out);
345             serialize(item.w3, out);
346 
347             serialize(item.w1m, out);
348             serialize(item.w2m, out);
349             serialize(item.w3m, out);
350         }
351         catch (serialization_error& e)
352         {
353             throw serialization_error(e.info + "\n   while serializing object of type mlp_kernel_1");
354         }
355     }
356 
deserialize(mlp_kernel_1 & item,std::istream & in)357     inline void deserialize (
358         mlp_kernel_1& item,
359         std::istream& in
360     )
361     {
362         try
363         {
364             deserialize(item.input_nodes, in);
365             deserialize(item.first_hidden_nodes, in);
366             deserialize(item.second_hidden_nodes, in);
367             deserialize(item.output_nodes, in);
368             deserialize(item.alpha, in);
369             deserialize(item.momentum, in);
370 
371             deserialize(item.w1, in);
372             deserialize(item.w2, in);
373             deserialize(item.w3, in);
374 
375             deserialize(item.w1m, in);
376             deserialize(item.w2m, in);
377             deserialize(item.w3m, in);
378 
379             item.z.set_size(item.input_nodes+1,1);
380         }
381         catch (serialization_error& e)
382         {
383             // give item a reasonable value since the deserialization failed
384             mlp_kernel_1(1,1).swap(item);
385             throw serialization_error(e.info + "\n   while deserializing object of type mlp_kernel_1");
386         }
387     }
388 
389 // ----------------------------------------------------------------------------------------
390 
391 }
392 
393 #endif // DLIB_MLp_KERNEL_1_
394 
395