1 // OpenNN: Open Neural Networks Library
2 // www.opennn.net
3 //
4 // L E A R N I N G R A T E A L G O R I T H M C L A S S
5 //
6 // Artificial Intelligence Techniques SL
7 // artelnics@artelnics.com
8
9 #include "learning_rate_algorithm.h"
10
11 namespace OpenNN
12 {
13
14 /// Default constructor.
15 /// It creates a learning rate algorithm object not associated to any loss index object.
16 /// It also initializes the class members to their default values.
17
LearningRateAlgorithm()18 LearningRateAlgorithm::LearningRateAlgorithm()
19 : loss_index_pointer(nullptr)
20 {
21 set_default();
22 }
23
24
/// Loss index constructor.
/// It creates a learning rate algorithm associated to a loss index.
/// It also initializes the class members to their default values.
/// @param new_loss_index_pointer Pointer to a loss index object (not owned by this object).

LearningRateAlgorithm::LearningRateAlgorithm(LossIndex* new_loss_index_pointer)
    : loss_index_pointer(new_loss_index_pointer)
{
    set_default();
}
35
36
37 /// Destructor.
38
~LearningRateAlgorithm()39 LearningRateAlgorithm::~LearningRateAlgorithm()
40 {
41 delete non_blocking_thread_pool;
42 delete thread_pool_device;
43 }
44
45
46 /// Returns a pointer to the loss index object
47 /// to which the learning rate algorithm is associated.
48 /// If the loss index pointer is nullptr, this method throws an exception.
49
get_loss_index_pointer() const50 LossIndex* LearningRateAlgorithm::get_loss_index_pointer() const
51 {
52 #ifdef __OPENNN_DEBUG__
53
54 if(!loss_index_pointer)
55 {
56 ostringstream buffer;
57
58 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
59 << "LossIndex* get_loss_index_pointer() const method.\n"
60 << "Loss index pointer is nullptr.\n";
61
62 throw logic_error(buffer.str());
63 }
64
65 #endif
66
67 return loss_index_pointer;
68 }
69
70
71 /// Returns true if this learning rate algorithm has an associated loss index,
72 /// and false otherwise.
73
has_loss_index() const74 bool LearningRateAlgorithm::has_loss_index() const
75 {
76 if(loss_index_pointer)
77 {
78 return true;
79 }
80 else
81 {
82 return false;
83 }
84 }
85
86
/// Returns the learning rate method used for training
/// (GoldenSection or BrentMethod).

const LearningRateAlgorithm::LearningRateMethod& LearningRateAlgorithm::get_learning_rate_method() const
{
    return learning_rate_method;
}
93
94
95 /// Returns a string with the name of the learning rate method to be used.
96
write_learning_rate_method() const97 string LearningRateAlgorithm::write_learning_rate_method() const
98 {
99 switch(learning_rate_method)
100 {
101 case GoldenSection:
102 return "GoldenSection";
103
104 case BrentMethod:
105 return "BrentMethod";
106 }
107
108 return string();
109 }
110
111
get_learning_rate_tolerance() const112 const type& LearningRateAlgorithm::get_learning_rate_tolerance() const
113 {
114 return learning_rate_tolerance;
115 }
116
117
/// Returns true if messages from this class can be displayed on the screen, or false if messages from
/// this class can't be displayed on the screen.

const bool& LearningRateAlgorithm::get_display() const
{
    return display;
}
125
126
127 /// Sets the loss index pointer to nullptr.
128 /// It also sets the rest of members to their default values.
129
set()130 void LearningRateAlgorithm::set()
131 {
132 loss_index_pointer = nullptr;
133
134 set_default();
135 }
136
137
138 /// Sets a new loss index pointer.
139 /// It also sets the rest of members to their default values.
140 /// @param new_loss_index_pointer Pointer to a loss index object.
141
set(LossIndex * new_loss_index_pointer)142 void LearningRateAlgorithm::set(LossIndex* new_loss_index_pointer)
143 {
144 loss_index_pointer = new_loss_index_pointer;
145
146 set_default();
147 }
148
149
150 /// Sets the members of the learning rate algorithm to their default values.
151
set_default()152 void LearningRateAlgorithm::set_default()
153 {
154 delete non_blocking_thread_pool;
155 delete thread_pool_device;
156
157 const int n = omp_get_max_threads();
158 non_blocking_thread_pool = new NonBlockingThreadPool(n);
159 thread_pool_device = new ThreadPoolDevice(non_blocking_thread_pool, n);
160
161 // TRAINING OPERATORS
162
163 learning_rate_method = BrentMethod;
164
165 // TRAINING PARAMETERS
166
167 learning_rate_tolerance = static_cast<type>(1.0e-3);
168 loss_tolerance = static_cast<type>(1.0e-3);
169 }
170
171
/// Sets a pointer to a loss index object to be associated to the optimization algorithm.
/// Unlike set(LossIndex*), this does not reset the other members.
/// @param new_loss_index_pointer Pointer to a loss index object (not owned).

void LearningRateAlgorithm::set_loss_index_pointer(LossIndex* new_loss_index_pointer)
{
    loss_index_pointer = new_loss_index_pointer;
}
179
180
set_threads_number(const int & new_threads_number)181 void LearningRateAlgorithm::set_threads_number(const int& new_threads_number)
182 {
183 if(non_blocking_thread_pool != nullptr) delete this->non_blocking_thread_pool;
184 if(thread_pool_device != nullptr) delete this->thread_pool_device;
185
186 non_blocking_thread_pool = new NonBlockingThreadPool(new_threads_number);
187 thread_pool_device = new ThreadPoolDevice(non_blocking_thread_pool, new_threads_number);
188 }
189
190
/// Sets a new learning rate method to be used for training.
/// @param new_learning_rate_method Learning rate method (GoldenSection or BrentMethod).

void LearningRateAlgorithm::set_learning_rate_method(
    const LearningRateAlgorithm::LearningRateMethod& new_learning_rate_method)
{
    learning_rate_method = new_learning_rate_method;
}
199
200
201 /// Sets the method for obtaining the learning rate from a string with the name of the method.
202 /// @param new_learning_rate_method Name of learning rate method("Fixed", "GoldenSection", "BrentMethod").
203
set_learning_rate_method(const string & new_learning_rate_method)204 void LearningRateAlgorithm::set_learning_rate_method(const string& new_learning_rate_method)
205 {
206 if(new_learning_rate_method == "GoldenSection")
207 {
208 learning_rate_method = GoldenSection;
209 }
210 else if(new_learning_rate_method == "BrentMethod")
211 {
212 learning_rate_method = BrentMethod;
213 }
214 else
215 {
216 ostringstream buffer;
217
218 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
219 << "void set_method(const string&) method.\n"
220 << "Unknown learning rate method: " << new_learning_rate_method << ".\n";
221
222 throw logic_error(buffer.str());
223 }
224 }
225
226
/// Sets a new tolerance value to be used in line minimization.
/// @param new_learning_rate_tolerance Tolerance value in line minimization; must be strictly positive
/// (checked in debug builds only).

void LearningRateAlgorithm::set_learning_rate_tolerance(const type& new_learning_rate_tolerance)
{
#ifdef __OPENNN_DEBUG__

    if(new_learning_rate_tolerance <= static_cast<type>(0.0))
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
               << "void set_learning_rate_tolerance(const type&) method.\n"
               << "Tolerance must be greater than 0.\n";

        throw logic_error(buffer.str());
    }

#endif

    // Set learning rate tolerance

    learning_rate_tolerance = new_learning_rate_tolerance;
}
251
252
/// Sets a new display value.
/// If it is set to true messages from this class are to be displayed on the screen;
/// if it is set to false messages from this class are not to be displayed on the screen.
/// @param new_display Display value.

void LearningRateAlgorithm::set_display(const bool& new_display)
{
    display = new_display;
}
262
263
/// Returns a pair with two elements:
/// (i) the learning rate calculated by means of the configured algorithm (golden section or Brent), and
/// (ii) the loss for that learning rate.
/// Brackets a minimum along the training direction and then shrinks the bracket until
/// both the learning-rate width and the loss width fall below their tolerances.
/// @param batch Data batch used to evaluate the error.
/// @param forward_propagation Neural network forward-pass workspace (overwritten).
/// @param back_propagation Loss workspace; back_propagation.error is overwritten on every trial.
/// @param optimization_data Carries the current parameters, training direction and initial learning rate;
/// potential_parameters is used as scratch space.

pair<type,type> LearningRateAlgorithm::calculate_directional_point(
    const DataSet::Batch& batch,
    NeuralNetwork::ForwardPropagation& forward_propagation,
    LossIndex::BackPropagation& back_propagation,
    OptimizationAlgorithm::OptimizationData& optimization_data) const
{
    // NOTE(review): loss_index_pointer is dereferenced here, before the debug-mode
    // null check below can run — the check can never fire.

    const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();

#ifdef __OPENNN_DEBUG__

    if(loss_index_pointer == nullptr)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "pair<type, 1> calculate_directional_point() const method.\n"
               << "Pointer to loss index is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(neural_network_pointer == nullptr)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Tensor<type, 1> calculate_directional_point() const method.\n"
               << "Pointer to neural network is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(thread_pool_device == nullptr)
    {
        ostringstream buffer;

        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "pair<type, 1> calculate_directional_point() const method.\n"
               << "Pointer to thread pool device is nullptr.\n";

        throw logic_error(buffer.str());
    }

#endif

    ostringstream buffer;

    const type regularization_weight = loss_index_pointer->get_regularization_weight();


    // Bracket minimum

    Triplet triplet = calculate_bracketing_triplet(batch,
                                                   forward_propagation,
                                                   back_propagation,
                                                   optimization_data);

    try
    {
        triplet.check();
    }
    catch(const logic_error& error)
    {
        // Bracketing failed to produce a valid triplet: fall back to its best known point.

        //cout << "Triplet bracketing" << endl;

        //cout << error.what() << endl;

        return triplet.minimum();
    }

    // V is the trial point (learning rate, loss) evaluated on each iteration.

    pair<type, type> V;

    // Reduce the interval

    while(fabs(triplet.A.first-triplet.B.first) > learning_rate_tolerance
            || fabs(triplet.A.second-triplet.B.second) > loss_tolerance)
    {
        try
        {
            switch(learning_rate_method)
            {
            case GoldenSection: V.first = calculate_golden_section_learning_rate(triplet); break;

            case BrentMethod: V.first = calculate_Brent_method_learning_rate(triplet); break;
            }
        }
        catch(const logic_error& error)
        {
            // The interior-point calculation failed: return the best bracketed point.

            cout << "Learning rate" << endl;

            //cout << error.what() << endl;

            return triplet.minimum();
        }

        // Calculate loss for V

        optimization_data.potential_parameters.device(*thread_pool_device)
                = optimization_data.parameters + optimization_data.training_direction*V.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        V.second = back_propagation.error + regularization_weight*regularization;

        // Update points: keep a triplet A < U < B with f(U) below both endpoints.

        if(V.first <= triplet.U.first)
        {
            if(V.second >= triplet.U.second)
            {
                // Trial to the left of U and no better: shrink from the left.

                triplet.A = V;
            }
            else if(V.second <= triplet.U.second)
            {
                // Trial to the left of U and better: U becomes the right endpoint.

                triplet.B = triplet.U;
                triplet.U = V;
            }
        }
        else if(V.first >= triplet.U.first)
        {
            if(V.second >= triplet.U.second)
            {
                // Trial to the right of U and no better: shrink from the right.

                triplet.B = V;
            }
            else if(V.second <= triplet.U.second)
            {
                // Trial to the right of U and better: U becomes the left endpoint.

                triplet.A = triplet.U;
                triplet.U = V;
            }
        }
        else
        {
            // NOTE(review): unreachable in real arithmetic (the two branches above cover all
            // orderings), but kept as a NaN guard.

            buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
                   << "Tensor<type, 1> calculate_Brent_method_directional_point() const method.\n"
                   << "Unknown set:\n"
                   << "A = (" << triplet.A.first << "," << triplet.A.second << ")\n"
                   << "B = (" << triplet.B.first << "," << triplet.B.second << ")\n"
                   << "U = (" << triplet.U.first << "," << triplet.U.second << ")\n"
                   << "V = (" << V.first << "," << V.second << ")\n";

            throw logic_error(buffer.str());
        }

        // Check triplet

        try
        {
            triplet.check();
        }
        catch(const logic_error& error)
        {
            // The reduced triplet is degenerate: return its best point.

            //cout << "Triplet reduction" << endl;

            //cout << error.what() << endl;

            return triplet.minimum();
        }
    }

    return triplet.U;


    // NOTE(review): everything below this point is unreachable dead code kept from an
    // earlier exception-handling design.

    /*
    catch(range_error& e) // Interval is of length 0
    {
        cout << "Interval is of length 0" << endl;
        cerr << e.what() << endl;

        pair<type, type> A;
        A.first = 0;
        A.second = loss;

        return A;
    }
    catch(const logic_error& e)
    {
        cout << "Other exception" << endl;
        cerr << e.what() << endl;

        pair<type, type> X;
        X.first = optimization_data.initial_learning_rate;

        optimization_data.potential_parameters.device(*thread_pool_device)
                = optimization_data.parameters + optimization_data.training_direction*X.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        X.second = back_propagation.error + regularization_weight*regularization;

        if(X.second > loss)
        {
            X.first = 0;
            X.second = 0;
        }

        return X;
    }
    */



    return pair<type,type>();
}
482
483
/// Returns a bracketing triplet (A, U, B) of learning rates along the training direction,
/// with A.first < U.first < B.first and the loss at U below the losses at A and B.
/// This algorithm is used by the line minimization algorithms.
/// @param batch Data batch used to evaluate the error.
/// @param forward_propagation Neural network forward-pass workspace (overwritten).
/// @param back_propagation Carries the current loss in back_propagation.loss; error is overwritten per trial.
/// @param optimization_data Carries parameters, training direction and initial learning rate;
/// potential_parameters is used as scratch space.

LearningRateAlgorithm::Triplet LearningRateAlgorithm::calculate_bracketing_triplet(
    const DataSet::Batch& batch,
    NeuralNetwork::ForwardPropagation& forward_propagation,
    LossIndex::BackPropagation& back_propagation,
    OptimizationAlgorithm::OptimizationData& optimization_data) const
{
    // NOTE(review): loss_index_pointer is dereferenced here, before the debug-mode
    // null check below can run — the check can never fire.

    const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();

#ifdef __OPENNN_DEBUG__

    ostringstream buffer;

    if(loss_index_pointer == nullptr)
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Pointer to loss index is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(neural_network_pointer == nullptr)
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Pointer to neural network is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(thread_pool_device == nullptr)
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Pointer to thread pool device is nullptr.\n";

        throw logic_error(buffer.str());
    }

    if(is_zero(optimization_data.training_direction))
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Training direction is zero.\n";

        throw logic_error(buffer.str());
    }

    if(optimization_data.initial_learning_rate < numeric_limits<type>::min())
    {
        buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
               << "Triplet calculate_bracketing_triplet() const method.\n"
               << "Initial learning rate is zero.\n";

        throw logic_error(buffer.str());
    }

#endif

    const type loss = back_propagation.loss;

    const type regularization_weight = loss_index_pointer->get_regularization_weight();

    Triplet triplet;

    // Left point: the current parameters (zero step) with the current loss.

    triplet.A.first = 0;
    triplet.A.second = loss;

    // Right point: step forward in multiples of the initial learning rate until the
    // loss actually changes.
    // NOTE(review): if the loss is exactly flat along the direction this loop does not
    // terminate — confirm an upper bound is acceptable to callers.

    Index count = 0;

    do
    {
        count++;

        triplet.B.first = optimization_data.initial_learning_rate*count;

        optimization_data.potential_parameters.device(*thread_pool_device)
                = optimization_data.parameters + optimization_data.training_direction*triplet.B.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        triplet.B.second = back_propagation.error + regularization_weight*regularization;

    } while(abs(triplet.A.second - triplet.B.second) < numeric_limits<type>::min());


    if(triplet.A.second > triplet.B.second)
    {
        // Loss decreased: expand to the right by the golden ratio until it rises again.

        triplet.U = triplet.B;

        triplet.B.first *= golden_ratio;

        optimization_data.potential_parameters.device(*thread_pool_device)
                = optimization_data.parameters + optimization_data.training_direction*triplet.B.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        triplet.B.second = back_propagation.error + regularization_weight*regularization;

        while(triplet.U.second > triplet.B.second)
        {
            // Shift the triplet rightwards and take another golden-ratio step.

            triplet.A = triplet.U;
            triplet.U = triplet.B;

            triplet.B.first *= golden_ratio;

            optimization_data.potential_parameters.device(*thread_pool_device) = optimization_data.parameters + optimization_data.training_direction*triplet.B.first;

            neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

            loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

            const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

            triplet.B.second = back_propagation.error + regularization_weight*regularization;
        }
    }
    else if(triplet.A.second < triplet.B.second)
    {
        // Loss increased: search for an interior point by golden-section subdivision.

        triplet.U.first = triplet.A.first + (triplet.B.first - triplet.A.first)*static_cast<type>(0.382);

        optimization_data.potential_parameters.device(*thread_pool_device) = optimization_data.parameters + optimization_data.training_direction*triplet.U.first;

        neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

        loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

        const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

        triplet.U.second = back_propagation.error + regularization_weight*regularization;

        while(triplet.A.second < triplet.U.second)
        {
            // Interior point still worse than A: shrink the interval towards A.

            triplet.B = triplet.U;

            triplet.U.first = triplet.A.first + (triplet.B.first-triplet.A.first)*static_cast<type>(0.382);

            optimization_data.potential_parameters.device(*thread_pool_device)
                    = optimization_data.parameters + optimization_data.training_direction*triplet.U.first;

            neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);

            loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);

            const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);

            triplet.U.second = back_propagation.error + regularization_weight*regularization;

            if(triplet.U.first - triplet.A.first <= learning_rate_tolerance)
            {
                // The interval collapsed: return a degenerate triplet at A.

                triplet.U = triplet.A;
                triplet.B = triplet.A;

                return triplet;
            }
        }
    }

    return triplet;
}
662
663 /// Calculates the golden section point within a minimum interval defined by three points.
664 /// @param triplet Triplet containing a minimum.
665
calculate_golden_section_learning_rate(const Triplet & triplet) const666 type LearningRateAlgorithm::calculate_golden_section_learning_rate(const Triplet& triplet) const
667 {
668 type learning_rate;
669
670 const type middle = triplet.A.first + static_cast<type>(0.5)*(triplet.B.first - triplet.A.first);
671
672 if(triplet.U.first < middle)
673 {
674 learning_rate = triplet.A.first + static_cast<type>(0.618)*(triplet.B.first - triplet.A.first);
675 }
676 else
677 {
678 learning_rate = triplet.A.first + static_cast<type>(0.382)*(triplet.B.first - triplet.A.first);
679 }
680
681 #ifdef __OPENNN_DEBUG__
682
683 if(learning_rate < triplet.A.first)
684 {
685 ostringstream buffer;
686
687 buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
688 << "type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
689 << "Learning rate(" << learning_rate << ") is less than left point("
690 << triplet.A.first << ").\n";
691
692 throw logic_error(buffer.str());
693 }
694
695 if(learning_rate > triplet.B.first)
696 {
697 ostringstream buffer;
698
699 buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
700 << "type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
701 << "Learning rate(" << learning_rate << ") is greater than right point("
702 << triplet.B.first << ").\n";
703
704 throw logic_error(buffer.str());
705 }
706
707 #endif
708
709 return learning_rate;
710 }
711
712
713 /// Returns the minimimal learning rate of a parabola defined by three directional points.
714 /// @param triplet Triplet containing a minimum.
715
calculate_Brent_method_learning_rate(const Triplet & triplet) const716 type LearningRateAlgorithm::calculate_Brent_method_learning_rate(const Triplet& triplet) const
717 {
718 const type a = triplet.A.first;
719 const type u = triplet.U.first;
720 const type b = triplet.B.first;
721
722 const type fa = triplet.A.second;
723 const type fu = triplet.U.second;
724 const type fb = triplet.B.second;
725
726
727 type numerator = (u-a)*(u-a)*(fu-fb) - (u-b)*(u-b)*(fu-fa);
728
729 type denominator = (u-a)*(fu-fb) - (u-b)*(fu-fa);
730
731 return u - 0.5*numerator/denominator;
732
733 /*
734 const type c = -(triplet.A.second*(triplet.U.first-triplet.B.first)
735 + triplet.U.second*(triplet.B.first-triplet.A.first)
736 + triplet.B.second*(triplet.A.first-triplet.U.first))
737 /((triplet.A.first-triplet.U.first)*(triplet.U.first-triplet.B.first)*(triplet.B.first-triplet.A.first));
738
739 if(abs(c) < numeric_limits<type>::min())
740 {
741 ostringstream buffer;
742
743 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
744 << "type calculate_Brent_method_learning_rate(const Triplet&) const method.\n"
745 << "Parabola cannot be constructed.\n";
746
747 throw logic_error(buffer.str());
748 }
749 else if(c < 0)
750 {
751 ostringstream buffer;
752
753 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
754 << "type calculate_Brent_method_learning_rate(const Triplet&) const method.\n"
755 << "Parabola does not have a minimum but a maximum.\n";
756
757 throw logic_error(buffer.str());
758 }
759
760 const type b = (triplet.A.second*(triplet.U.first*triplet.U.first-triplet.B.first*triplet.B.first)
761 + triplet.U.second*(triplet.B.first*triplet.B.first-triplet.A.first*triplet.A.first)
762 + triplet.B.second*(triplet.A.first*triplet.A.first-triplet.U.first*triplet.U.first))
763 /((triplet.A.first-triplet.U.first)*(triplet.U.first-triplet.B.first)*(triplet.B.first-triplet.A.first));
764
765 const type Brent_method_learning_rate = -b/(static_cast<type>(2.0)*c);
766
767 if(Brent_method_learning_rate < triplet.A.first || Brent_method_learning_rate > triplet.B.first)
768 {
769 ostringstream buffer;
770
771 buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
772 << "type calculate_Brent_method_learning_rate() const method.\n"
773 << "Brent method learning rate is not inside interval.\n"
774 << "Interval:(" << triplet.A.first << "," << triplet.B.first << ")\n"
775 << "Brent method learning rate: " << Brent_method_learning_rate << endl;
776
777
778 throw logic_error(buffer.str());
779 }
780
781 return Brent_method_learning_rate;
782 */
783 }
784
785
786 /// Serializes the learning rate algorithm object into a XML document of the TinyXML library
787 /// without keep the DOM tree in memory.
788 /// See the OpenNN manual for more information about the format of this document.
789
write_XML(tinyxml2::XMLPrinter & file_stream) const790 void LearningRateAlgorithm::write_XML(tinyxml2::XMLPrinter& file_stream) const
791 {
792 ostringstream buffer;
793
794 // Learning rate algorithm
795
796 file_stream.OpenElement("LearningRateAlgorithm");
797
798 // Learning rate method
799
800 file_stream.OpenElement("LearningRateMethod");
801
802 file_stream.PushText(write_learning_rate_method().c_str());
803
804 file_stream.CloseElement();
805
806 // Learning rate tolerance
807
808 file_stream.OpenElement("LearningRateTolerance");
809
810 buffer.str("");
811 buffer << learning_rate_tolerance;
812
813 file_stream.PushText(buffer.str().c_str());
814
815 file_stream.CloseElement();
816
817 // Learning rate algorithm (end tag)
818
819 file_stream.CloseElement();
820 }
821
822
/// Loads a learning rate algorithm object from a XML-type file.
/// Please mind about the file format, which is specified in the manual.
/// Missing optional elements are skipped; invalid values are reported to cerr
/// and otherwise ignored.
/// @param document TinyXML document with the learning rate algorithm members.
/// @throws logic_error if the root "LearningRateAlgorithm" element is absent.

void LearningRateAlgorithm::from_XML(const tinyxml2::XMLDocument& document)
{
    const tinyxml2::XMLElement* root_element = document.FirstChildElement("LearningRateAlgorithm");

    if(!root_element)
    {
        ostringstream buffer;

        buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
               << "void from_XML(const tinyxml2::XMLDocument&) method.\n"
               << "Learning rate algorithm element is nullptr.\n";

        throw logic_error(buffer.str());
    }

    // Learning rate method (optional element)
    {
        const tinyxml2::XMLElement* element = root_element->FirstChildElement("LearningRateMethod");

        if(element)
        {
            string new_learning_rate_method = element->GetText();

            try
            {
                set_learning_rate_method(new_learning_rate_method);
            }
            catch(const logic_error& e)
            {
                // Unknown method name: report and keep the current setting.

                cerr << e.what() << endl;
            }
        }
    }

    // Learning rate tolerance (optional element)
    {
        const tinyxml2::XMLElement* element = root_element->FirstChildElement("LearningRateTolerance");

        if(element)
        {
            const type new_learning_rate_tolerance = static_cast<type>(atof(element->GetText()));

            try
            {
                set_learning_rate_tolerance(new_learning_rate_tolerance);
            }
            catch(const logic_error& e)
            {
                // Non-positive tolerance: report and keep the current setting.

                cerr << e.what() << endl;
            }
        }
    }

    // Display warnings (optional element; any text other than "0" enables display)
    {
        const tinyxml2::XMLElement* element = root_element->FirstChildElement("Display");

        if(element)
        {
            const string new_display = element->GetText();

            try
            {
                set_display(new_display != "0");
            }
            catch(const logic_error& e)
            {
                cerr << e.what() << endl;
            }
        }
    }
}
899
900 }
901
902
903 // OpenNN: Open Neural Networks Library.
904 // Copyright(C) 2005-2020 Artificial Intelligence Techniques, SL.
905 //
906 // This library is free software; you can redistribute it and/or
907 // modify it under the terms of the GNU Lesser General Public
908 // License as published by the Free Software Foundation; either
909 // version 2.1 of the License, or any later version.
910 //
911 // This library is distributed in the hope that it will be useful,
912 // but WITHOUT ANY WARRANTY; without even the implied warranty of
913 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
914 // Lesser General Public License for more details.
915
916 // You should have received a copy of the GNU Lesser General Public
917 // License along with this library; if not, write to the Free Software
918 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
919