1 // Copyright (C) 2007 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_MLp_KERNEL_1_ 4 #define DLIB_MLp_KERNEL_1_ 5 6 #include "../algs.h" 7 #include "../serialize.h" 8 #include "../matrix.h" 9 #include "../rand.h" 10 #include "mlp_kernel_abstract.h" 11 #include <ctime> 12 #include <sstream> 13 14 namespace dlib 15 { 16 17 class mlp_kernel_1 : noncopyable 18 { 19 /*! 20 INITIAL VALUE 21 The network is initially initialized with random weights 22 23 CONVENTION 24 - input_layer_nodes() == input_nodes 25 - first_hidden_layer_nodes() == first_hidden_nodes 26 - second_hidden_layer_nodes() == second_hidden_nodes 27 - output_layer_nodes() == output_nodes 28 - get_alpha == alpha 29 - get_momentum() == momentum 30 31 32 - if (second_hidden_nodes == 0) then 33 - for all i and j: 34 - w1(i,j) == the weight on the link from node i in the first hidden layer 35 to input node j 36 - w3(i,j) == the weight on the link from node i in the output layer 37 to first hidden layer node j 38 - for all i and j: 39 - w1m == the momentum terms for w1 from the previous update 40 - w3m == the momentum terms for w3 from the previous update 41 - else 42 - for all i and j: 43 - w1(i,j) == the weight on the link from node i in the first hidden layer 44 to input node j 45 - w2(i,j) == the weight on the link from node i in the second hidden layer 46 to first hidden layer node j 47 - w3(i,j) == the weight on the link from node i in the output layer 48 to second hidden layer node j 49 - for all i and j: 50 - w1m == the momentum terms for w1 from the previous update 51 - w2m == the momentum terms for w2 from the previous update 52 - w3m == the momentum terms for w3 from the previous update 53 !*/ 54 55 public: 56 57 mlp_kernel_1 ( 58 long nodes_in_input_layer, 59 long nodes_in_first_hidden_layer, 60 long nodes_in_second_hidden_layer = 0, 61 long nodes_in_output_layer = 1, 62 double alpha_ = 0.1, 63 double momentum_ = 0.8 64 ) : input_nodes(nodes_in_input_layer)65 input_nodes(nodes_in_input_layer), 66 first_hidden_nodes(nodes_in_first_hidden_layer), 67 second_hidden_nodes(nodes_in_second_hidden_layer), 68 output_nodes(nodes_in_output_layer), 69 alpha(alpha_), 70 momentum(momentum_) 71 { 72 73 // seed the random number generator 74 std::ostringstream sout; 75 sout << time(0); 76 rand_nums.set_seed(sout.str()); 77 78 w1.set_size(first_hidden_nodes+1, input_nodes+1); 79 w1m.set_size(first_hidden_nodes+1, input_nodes+1); 80 z.set_size(input_nodes+1,1); 81 82 if (second_hidden_nodes != 0) 83 { 84 w2.set_size(second_hidden_nodes+1, first_hidden_nodes+1); 85 w3.set_size(output_nodes, second_hidden_nodes+1); 86 87 w2m.set_size(second_hidden_nodes+1, first_hidden_nodes+1); 88 w3m.set_size(output_nodes, second_hidden_nodes+1); 89 } 90 else 91 { 92 w3.set_size(output_nodes, first_hidden_nodes+1); 93 94 w3m.set_size(output_nodes, first_hidden_nodes+1); 95 } 96 97 reset(); 98 } 99 ~mlp_kernel_1()100 virtual ~mlp_kernel_1 ( 101 ) {} 102 reset()103 void reset ( 104 ) 105 { 106 // randomize the weights for the first layer 107 for (long r = 0; r < w1.nr(); ++r) 108 for (long c = 0; c < w1.nc(); ++c) 109 w1(r,c) = rand_nums.get_random_double(); 110 111 // randomize the weights for the second layer 112 for (long r = 0; r < w2.nr(); ++r) 113 for (long c = 0; c < w2.nc(); ++c) 114 w2(r,c) = rand_nums.get_random_double(); 115 116 // randomize the weights for the third layer 117 for (long r = 0; r < w3.nr(); ++r) 118 for (long c = 0; c < w3.nc(); ++c) 119 w3(r,c) = rand_nums.get_random_double(); 120 121 // zero all the momentum terms 122 set_all_elements(w1m,0); 123 set_all_elements(w2m,0); 124 set_all_elements(w3m,0); 125 } 126 input_layer_nodes()127 long input_layer_nodes ( 128 ) const { return input_nodes; } 129 first_hidden_layer_nodes()130 long first_hidden_layer_nodes ( 131 ) const { return first_hidden_nodes; } 132 second_hidden_layer_nodes()133 long second_hidden_layer_nodes ( 134 ) const { return second_hidden_nodes; } 135 output_layer_nodes()136 long output_layer_nodes ( 137 ) const { return output_nodes; } 138 get_alpha()139 double get_alpha ( 140 ) const { return alpha; } 141 get_momentum()142 double get_momentum ( 143 ) const { return momentum; } 144 145 template <typename EXP> operator()146 const matrix<double> operator() ( 147 const matrix_exp<EXP>& in 148 ) const 149 { 150 for (long i = 0; i < in.nr(); ++i) 151 z(i) = in(i); 152 // insert the bias 153 z(z.nr()-1) = -1; 154 155 tmp1 = sigmoid(w1*z); 156 // insert the bias 157 tmp1(tmp1.nr()-1) = -1; 158 159 if (second_hidden_nodes == 0) 160 { 161 return sigmoid(w3*tmp1); 162 } 163 else 164 { 165 tmp2 = sigmoid(w2*tmp1); 166 // insert the bias 167 tmp2(tmp2.nr()-1) = -1; 168 169 return sigmoid(w3*tmp2); 170 } 171 } 172 173 template <typename EXP1, typename EXP2> train(const matrix_exp<EXP1> & example_in,const matrix_exp<EXP2> & example_out)174 void train ( 175 const matrix_exp<EXP1>& example_in, 176 const matrix_exp<EXP2>& example_out 177 ) 178 { 179 for (long i = 0; i < example_in.nr(); ++i) 180 z(i) = example_in(i); 181 // insert the bias 182 z(z.nr()-1) = -1; 183 184 tmp1 = sigmoid(w1*z); 185 // insert the bias 186 tmp1(tmp1.nr()-1) = -1; 187 188 189 if (second_hidden_nodes == 0) 190 { 191 o = sigmoid(w3*tmp1); 192 193 // now compute the errors and propagate them backwards though the network 194 e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o); 195 e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w3)*e3 ); 196 197 // compute the new weight updates 198 w3m = alpha * e3*trans(tmp1) + w3m*momentum; 199 w1m = alpha * e1*trans(z) + w1m*momentum; 200 201 // now update the weights 202 w1 += w1m; 203 w3 += w3m; 204 } 205 else 206 { 207 tmp2 = sigmoid(w2*tmp1); 208 // insert the bias 209 tmp2(tmp2.nr()-1) = -1; 210 211 o = sigmoid(w3*tmp2); 212 213 214 // now compute the errors and propagate them backwards though the network 215 e3 = pointwise_multiply(example_out-o, uniform_matrix<double>(output_nodes,1,1.0)-o, o); 216 e2 = pointwise_multiply(tmp2, uniform_matrix<double>(second_hidden_nodes+1,1,1.0) - tmp2, trans(w3)*e3 ); 217 e1 = pointwise_multiply(tmp1, uniform_matrix<double>(first_hidden_nodes+1,1,1.0) - tmp1, trans(w2)*e2 ); 218 219 // compute the new weight updates 220 w3m = alpha * e3*trans(tmp2) + w3m*momentum; 221 w2m = alpha * e2*trans(tmp1) + w2m*momentum; 222 w1m = alpha * e1*trans(z) + w1m*momentum; 223 224 // now update the weights 225 w1 += w1m; 226 w2 += w2m; 227 w3 += w3m; 228 } 229 } 230 231 template <typename EXP> train(const matrix_exp<EXP> & example_in,double example_out)232 void train ( 233 const matrix_exp<EXP>& example_in, 234 double example_out 235 ) 236 { 237 matrix<double,1,1> e_out; 238 e_out(0) = example_out; 239 train(example_in,e_out); 240 } 241 get_average_change()242 double get_average_change ( 243 ) const 244 { 245 // sum up all the weight changes 246 double delta = sum(abs(w1m)) + sum(abs(w2m)) + sum(abs(w3m)); 247 248 // divide by the number of weights 249 delta /= w1m.nr()*w1m.nc() + 250 w2m.nr()*w2m.nc() + 251 w3m.nr()*w3m.nc(); 252 253 return delta; 254 } 255 swap(mlp_kernel_1 & item)256 void swap ( 257 mlp_kernel_1& item 258 ) 259 { 260 exchange(input_nodes, item.input_nodes); 261 exchange(first_hidden_nodes, item.first_hidden_nodes); 262 exchange(second_hidden_nodes, item.second_hidden_nodes); 263 exchange(output_nodes, item.output_nodes); 264 exchange(alpha, item.alpha); 265 exchange(momentum, item.momentum); 266 267 w1.swap(item.w1); 268 w2.swap(item.w2); 269 w3.swap(item.w3); 270 271 w1m.swap(item.w1m); 272 w2m.swap(item.w2m); 273 w3m.swap(item.w3m); 274 275 // even swap the temporary matrices because this may ultimately result in 276 // fewer calls to new and delete. 277 e1.swap(item.e1); 278 e2.swap(item.e2); 279 e3.swap(item.e3); 280 z.swap(item.z); 281 tmp1.swap(item.tmp1); 282 tmp2.swap(item.tmp2); 283 o.swap(item.o); 284 } 285 286 287 friend void serialize ( 288 const mlp_kernel_1& item, 289 std::ostream& out 290 ); 291 292 friend void deserialize ( 293 mlp_kernel_1& item, 294 std::istream& in 295 ); 296 297 private: 298 299 long input_nodes; 300 long first_hidden_nodes; 301 long second_hidden_nodes; 302 long output_nodes; 303 double alpha; 304 double momentum; 305 306 matrix<double> w1; 307 matrix<double> w2; 308 matrix<double> w3; 309 310 matrix<double> w1m; 311 matrix<double> w2m; 312 matrix<double> w3m; 313 314 315 rand rand_nums; 316 317 // temporary storage 318 mutable matrix<double> e1, e2, e3; 319 mutable matrix<double> z, tmp1, tmp2, o; 320 }; 321 swap(mlp_kernel_1 & a,mlp_kernel_1 & b)322 inline void swap ( 323 mlp_kernel_1& a, 324 mlp_kernel_1& b 325 ) { a.swap(b); } 326 327 // ---------------------------------------------------------------------------------------- 328 serialize(const mlp_kernel_1 & item,std::ostream & out)329 inline void serialize ( 330 const mlp_kernel_1& item, 331 std::ostream& out 332 ) 333 { 334 try 335 { 336 serialize(item.input_nodes, out); 337 serialize(item.first_hidden_nodes, out); 338 serialize(item.second_hidden_nodes, out); 339 serialize(item.output_nodes, out); 340 serialize(item.alpha, out); 341 serialize(item.momentum, out); 342 343 serialize(item.w1, out); 344 serialize(item.w2, out); 345 serialize(item.w3, out); 346 347 serialize(item.w1m, out); 348 serialize(item.w2m, out); 349 serialize(item.w3m, out); 350 } 351 catch (serialization_error& e) 352 { 353 throw serialization_error(e.info + "\n while serializing object of type mlp_kernel_1"); 354 } 355 } 356 deserialize(mlp_kernel_1 & item,std::istream & in)357 inline void deserialize ( 358 mlp_kernel_1& item, 359 std::istream& in 360 ) 361 { 362 try 363 { 364 deserialize(item.input_nodes, in); 365 deserialize(item.first_hidden_nodes, in); 366 deserialize(item.second_hidden_nodes, in); 367 deserialize(item.output_nodes, in); 368 deserialize(item.alpha, in); 369 deserialize(item.momentum, in); 370 371 deserialize(item.w1, in); 372 deserialize(item.w2, in); 373 deserialize(item.w3, in); 374 375 deserialize(item.w1m, in); 376 deserialize(item.w2m, in); 377 deserialize(item.w3m, in); 378 379 item.z.set_size(item.input_nodes+1,1); 380 } 381 catch (serialization_error& e) 382 { 383 // give item a reasonable value since the deserialization failed 384 mlp_kernel_1(1,1).swap(item); 385 throw serialization_error(e.info + "\n while deserializing object of type mlp_kernel_1"); 386 } 387 } 388 389 // ---------------------------------------------------------------------------------------- 390 391 } 392 393 #endif // DLIB_MLp_KERNEL_1_ 394 395