1 //   OpenNN: Open Neural Networks Library
2 //   www.opennn.net
3 //
4 //   L E A R N I N G   R A T E   A L G O R I T H M   C L A S S
5 //
6 //   Artificial Intelligence Techniques SL
7 //   artelnics@artelnics.com
8 
9 #include "learning_rate_algorithm.h"
10 
11 namespace OpenNN
12 {
13 
14 /// Default constructor.
15 /// It creates a learning rate algorithm object not associated to any loss index object.
16 /// It also initializes the class members to their default values.
17 
LearningRateAlgorithm()18 LearningRateAlgorithm::LearningRateAlgorithm()
19     : loss_index_pointer(nullptr)
20 {
21     set_default();
22 }
23 
24 
/// Loss index constructor.
/// It creates a learning rate algorithm associated to a loss index object.
/// It also initializes the class members to their default values.
/// @param new_loss_index_pointer Pointer to a loss index object.
29 
LearningRateAlgorithm(LossIndex * new_loss_index_pointer)30 LearningRateAlgorithm::LearningRateAlgorithm(LossIndex* new_loss_index_pointer)
31     : loss_index_pointer(new_loss_index_pointer)
32 {
33     set_default();
34 }
35 
36 
37 /// Destructor.
38 
~LearningRateAlgorithm()39 LearningRateAlgorithm::~LearningRateAlgorithm()
40 {
41     delete non_blocking_thread_pool;
42     delete thread_pool_device;
43 }
44 
45 
46 /// Returns a pointer to the loss index object
47 /// to which the learning rate algorithm is associated.
48 /// If the loss index pointer is nullptr, this method throws an exception.
49 
get_loss_index_pointer() const50 LossIndex* LearningRateAlgorithm::get_loss_index_pointer() const
51 {
52 #ifdef __OPENNN_DEBUG__
53 
54     if(!loss_index_pointer)
55     {
56         ostringstream buffer;
57 
58         buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
59                << "LossIndex* get_loss_index_pointer() const method.\n"
60                << "Loss index pointer is nullptr.\n";
61 
62         throw logic_error(buffer.str());
63     }
64 
65 #endif
66 
67     return loss_index_pointer;
68 }
69 
70 
71 /// Returns true if this learning rate algorithm has an associated loss index,
72 /// and false otherwise.
73 
has_loss_index() const74 bool LearningRateAlgorithm::has_loss_index() const
75 {
76     if(loss_index_pointer)
77     {
78         return true;
79     }
80     else
81     {
82         return false;
83     }
84 }
85 
86 
87 /// Returns the learning rate method used for training.
88 
get_learning_rate_method() const89 const LearningRateAlgorithm::LearningRateMethod& LearningRateAlgorithm::get_learning_rate_method() const
90 {
91     return learning_rate_method;
92 }
93 
94 
95 /// Returns a string with the name of the learning rate method to be used.
96 
write_learning_rate_method() const97 string LearningRateAlgorithm::write_learning_rate_method() const
98 {
99     switch(learning_rate_method)
100     {
101     case GoldenSection:
102         return "GoldenSection";
103 
104     case BrentMethod:
105         return "BrentMethod";
106     }
107 
108     return string();
109 }
110 
111 
get_learning_rate_tolerance() const112 const type& LearningRateAlgorithm::get_learning_rate_tolerance() const
113 {
114     return learning_rate_tolerance;
115 }
116 
117 
118 /// Returns true if messages from this class can be displayed on the screen, or false if messages from
119 /// this class can't be displayed on the screen.
120 
get_display() const121 const bool& LearningRateAlgorithm::get_display() const
122 {
123     return display;
124 }
125 
126 
127 /// Sets the loss index pointer to nullptr.
128 /// It also sets the rest of members to their default values.
129 
set()130 void LearningRateAlgorithm::set()
131 {
132     loss_index_pointer = nullptr;
133 
134     set_default();
135 }
136 
137 
138 /// Sets a new loss index pointer.
139 /// It also sets the rest of members to their default values.
140 /// @param new_loss_index_pointer Pointer to a loss index object.
141 
set(LossIndex * new_loss_index_pointer)142 void LearningRateAlgorithm::set(LossIndex* new_loss_index_pointer)
143 {
144     loss_index_pointer = new_loss_index_pointer;
145 
146     set_default();
147 }
148 
149 
150 /// Sets the members of the learning rate algorithm to their default values.
151 
set_default()152 void LearningRateAlgorithm::set_default()
153 {
154     delete non_blocking_thread_pool;
155     delete thread_pool_device;
156 
157     const int n = omp_get_max_threads();
158     non_blocking_thread_pool = new NonBlockingThreadPool(n);
159     thread_pool_device = new ThreadPoolDevice(non_blocking_thread_pool, n);
160 
161     // TRAINING OPERATORS
162 
163     learning_rate_method = BrentMethod;
164 
165     // TRAINING PARAMETERS
166 
167     learning_rate_tolerance = static_cast<type>(1.0e-3);
168     loss_tolerance = static_cast<type>(1.0e-3);
169 }
170 
171 
172 /// Sets a pointer to a loss index object to be associated to the optimization algorithm.
173 /// @param new_loss_index_pointer Pointer to a loss index object.
174 
set_loss_index_pointer(LossIndex * new_loss_index_pointer)175 void LearningRateAlgorithm::set_loss_index_pointer(LossIndex* new_loss_index_pointer)
176 {
177     loss_index_pointer = new_loss_index_pointer;
178 }
179 
180 
set_threads_number(const int & new_threads_number)181 void LearningRateAlgorithm::set_threads_number(const int& new_threads_number)
182 {
183     if(non_blocking_thread_pool != nullptr) delete this->non_blocking_thread_pool;
184     if(thread_pool_device != nullptr) delete this->thread_pool_device;
185 
186     non_blocking_thread_pool = new NonBlockingThreadPool(new_threads_number);
187     thread_pool_device = new ThreadPoolDevice(non_blocking_thread_pool, new_threads_number);
188 }
189 
190 
191 /// Sets a new learning rate method to be used for training.
192 /// @param new_learning_rate_method Learning rate method.
193 
set_learning_rate_method(const LearningRateAlgorithm::LearningRateMethod & new_learning_rate_method)194 void LearningRateAlgorithm::set_learning_rate_method(
195         const LearningRateAlgorithm::LearningRateMethod& new_learning_rate_method)
196 {
197     learning_rate_method = new_learning_rate_method;
198 }
199 
200 
/// Sets the method for obtaining the learning rate from a string with the name of the method.
/// @param new_learning_rate_method Name of learning rate method ("GoldenSection" or "BrentMethod").
203 
set_learning_rate_method(const string & new_learning_rate_method)204 void LearningRateAlgorithm::set_learning_rate_method(const string& new_learning_rate_method)
205 {
206     if(new_learning_rate_method == "GoldenSection")
207     {
208         learning_rate_method = GoldenSection;
209     }
210     else if(new_learning_rate_method == "BrentMethod")
211     {
212         learning_rate_method = BrentMethod;
213     }
214     else
215     {
216         ostringstream buffer;
217 
218         buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
219                << "void set_method(const string&) method.\n"
220                << "Unknown learning rate method: " << new_learning_rate_method << ".\n";
221 
222         throw logic_error(buffer.str());
223     }
224 }
225 
226 
227 /// Sets a new tolerance value to be used in line minimization.
228 /// @param new_learning_rate_tolerance Tolerance value in line minimization.
229 
set_learning_rate_tolerance(const type & new_learning_rate_tolerance)230 void LearningRateAlgorithm::set_learning_rate_tolerance(const type& new_learning_rate_tolerance)
231 {
232 #ifdef __OPENNN_DEBUG__
233 
234     if(new_learning_rate_tolerance <= static_cast<type>(0.0))
235     {
236         ostringstream buffer;
237 
238         buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
239                << "void set_learning_rate_tolerance(const type&) method.\n"
240                << "Tolerance must be greater than 0.\n";
241 
242         throw logic_error(buffer.str());
243     }
244 
245 #endif
246 
247     // Set loss tolerance
248 
249     learning_rate_tolerance = new_learning_rate_tolerance;
250 }
251 
252 
253 /// Sets a new display value.
254 /// If it is set to true messages from this class are to be displayed on the screen;
255 /// if it is set to false messages from this class are not to be displayed on the screen.
256 /// @param new_display Display value.
257 
set_display(const bool & new_display)258 void LearningRateAlgorithm::set_display(const bool& new_display)
259 {
260     display = new_display;
261 }
262 
263 
/// Returns a pair with two elements:
/// (i) the learning rate calculated by means of the corresponding algorithm, and
/// (ii) the loss for that learning rate.
/// @param batch Batch of data used to evaluate the loss.
/// @param forward_propagation Forward propagation workspace of the neural network.
/// @param back_propagation Back propagation workspace of the loss index.
/// @param optimization_data Optimization data holding the parameters, training direction and initial learning rate.
270 
calculate_directional_point(const DataSet::Batch & batch,NeuralNetwork::ForwardPropagation & forward_propagation,LossIndex::BackPropagation & back_propagation,OptimizationAlgorithm::OptimizationData & optimization_data) const271 pair<type,type> LearningRateAlgorithm::calculate_directional_point(
272     const DataSet::Batch& batch,
273     NeuralNetwork::ForwardPropagation& forward_propagation,
274     LossIndex::BackPropagation& back_propagation,
275     OptimizationAlgorithm::OptimizationData& optimization_data) const
276 {
277     const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();
278 
279 #ifdef __OPENNN_DEBUG__
280 
281     if(loss_index_pointer == nullptr)
282     {
283         ostringstream buffer;
284 
285         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
286                << "pair<type, 1> calculate_directional_point() const method.\n"
287                << "Pointer to loss index is nullptr.\n";
288 
289         throw logic_error(buffer.str());
290     }
291 
292     if(neural_network_pointer == nullptr)
293     {
294         ostringstream buffer;
295 
296         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
297                << "Tensor<type, 1> calculate_directional_point() const method.\n"
298                << "Pointer to neural network is nullptr.\n";
299 
300         throw logic_error(buffer.str());
301     }
302 
303     if(thread_pool_device == nullptr)
304     {
305         ostringstream buffer;
306 
307         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
308                << "pair<type, 1> calculate_directional_point() const method.\n"
309                << "Pointer to thread pool device is nullptr.\n";
310 
311         throw logic_error(buffer.str());
312     }
313 
314 #endif
315 
316     ostringstream buffer;
317 
318     const type regularization_weight = loss_index_pointer->get_regularization_weight();
319 
320 
321     // Bracket minimum
322 
323     Triplet triplet = calculate_bracketing_triplet(batch,
324                                                    forward_propagation,
325                                                    back_propagation,
326                                                    optimization_data);
327 
328     try
329     {
330         triplet.check();
331     }
332     catch(const logic_error& error)
333     {
334         //cout << "Triplet bracketing" << endl;
335 
336         //cout << error.what() << endl;
337 
338         return triplet.minimum();
339     }
340 
341     pair<type, type> V;
342 
343     // Reduce the interval
344 
345     while(fabs(triplet.A.first-triplet.B.first) > learning_rate_tolerance
346       ||  fabs(triplet.A.second-triplet.B.second) > loss_tolerance)
347     {
348         try
349         {
350             switch(learning_rate_method)
351             {
352                 case GoldenSection: V.first = calculate_golden_section_learning_rate(triplet); break;
353 
354                 case BrentMethod: V.first = calculate_Brent_method_learning_rate(triplet); break;
355             }
356         }
357         catch(const logic_error& error)
358         {
359             cout << "Learning rate" << endl;
360 
361             //cout << error.what() << endl;
362 
363             return triplet.minimum();
364         }
365 
366         // Calculate loss for V
367 
368         optimization_data.potential_parameters.device(*thread_pool_device)
369                 = optimization_data.parameters + optimization_data.training_direction*V.first;
370 
371         neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
372 
373         loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);
374 
375         const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);
376 
377         V.second = back_propagation.error + regularization_weight*regularization;
378 
379         // Update points
380 
381         if(V.first <= triplet.U.first)
382         {
383             if(V.second >= triplet.U.second)
384             {
385                 triplet.A = V;
386             }
387             else if(V.second <= triplet.U.second)
388             {
389                 triplet.B = triplet.U;
390                 triplet.U = V;
391             }
392         }
393         else if(V.first >= triplet.U.first)
394         {
395             if(V.second >= triplet.U.second)
396             {
397                 triplet.B = V;
398             }
399             else if(V.second <= triplet.U.second)
400             {
401                 triplet.A = triplet.U;
402                 triplet.U = V;
403             }
404         }
405         else
406         {
407             buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
408                    << "Tensor<type, 1> calculate_Brent_method_directional_point() const method.\n"
409                    << "Unknown set:\n"
410                    << "A = (" << triplet.A.first << "," << triplet.A.second << ")\n"
411                    << "B = (" << triplet.B.first << "," << triplet.B.second << ")\n"
412                    << "U = (" << triplet.U.first << "," << triplet.U.second << ")\n"
413                    << "V = (" << V.first << "," << V.second << ")\n";
414 
415             throw logic_error(buffer.str());
416         }
417 
418         // Check triplet
419 
420         try
421         {
422             triplet.check();
423         }
424         catch(const logic_error& error)
425         {
426             //cout << "Triplet reduction" << endl;
427 
428             //cout << error.what() << endl;
429 
430             return triplet.minimum();
431         }
432     }
433 
434     return triplet.U;
435 
436 
437 /*
438     catch(range_error& e) // Interval is of length 0
439     {
440         cout << "Interval is of length 0" << endl;
441         cerr << e.what() << endl;
442 
443         pair<type, type> A;
444         A.first = 0;
445         A.second = loss;
446 
447         return A;
448     }
449     catch(const logic_error& e)
450     {
451         cout << "Other exception" << endl;
452         cerr << e.what() << endl;
453 
454         pair<type, type> X;
455         X.first = optimization_data.initial_learning_rate;
456 
457         optimization_data.potential_parameters.device(*thread_pool_device)
458                 = optimization_data.parameters + optimization_data.training_direction*X.first;
459 
460         neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
461 
462         loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);
463 
464         const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);
465 
466         X.second = back_propagation.error + regularization_weight*regularization;
467 
468         if(X.second > loss)
469         {
470             X.first = 0;
471             X.second = 0;
472         }
473 
474         return X;
475     }
476 */
477 
478 
479 
480     return pair<type,type>();
481 }
482 
483 
/// Returns a bracketing triplet A-U-B whose interior point U has a loss
/// lower than (or equal to) both endpoints.
/// This algorithm is used by line minimization algorithms.
/// @param batch Batch of data used to evaluate the loss.
/// @param forward_propagation Forward propagation workspace of the neural network.
/// @param back_propagation Back propagation workspace of the loss index.
/// @param optimization_data Optimization data holding the parameters, training direction and initial learning rate.
489 
calculate_bracketing_triplet(const DataSet::Batch & batch,NeuralNetwork::ForwardPropagation & forward_propagation,LossIndex::BackPropagation & back_propagation,OptimizationAlgorithm::OptimizationData & optimization_data) const490 LearningRateAlgorithm::Triplet LearningRateAlgorithm::calculate_bracketing_triplet(
491     const DataSet::Batch& batch,
492     NeuralNetwork::ForwardPropagation& forward_propagation,
493     LossIndex::BackPropagation& back_propagation,
494     OptimizationAlgorithm::OptimizationData& optimization_data) const
495 {
496     const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();
497 
498 #ifdef __OPENNN_DEBUG__
499 
500     ostringstream buffer;
501 
502     if(loss_index_pointer == nullptr)
503     {
504         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
505                << "Triplet calculate_bracketing_triplet() const method.\n"
506                << "Pointer to loss index is nullptr.\n";
507 
508         throw logic_error(buffer.str());
509     }
510 
511     if(neural_network_pointer == nullptr)
512     {
513         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
514                << "Triplet calculate_bracketing_triplet() const method.\n"
515                << "Pointer to neural network is nullptr.\n";
516 
517         throw logic_error(buffer.str());
518     }
519 
520     if(thread_pool_device == nullptr)
521     {
522         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
523                << "Triplet calculate_bracketing_triplet() const method.\n"
524                << "Pointer to thread pool device is nullptr.\n";
525 
526         throw logic_error(buffer.str());
527     }
528 
529     if(is_zero(optimization_data.training_direction))
530     {
531         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
532                << "Triplet calculate_bracketing_triplet() const method.\n"
533                << "Training direction is zero.\n";
534 
535         throw logic_error(buffer.str());
536     }
537 
538     if(optimization_data.initial_learning_rate < numeric_limits<type>::min())
539     {
540         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
541                << "Triplet calculate_bracketing_triplet() const method.\n"
542                << "Initial learning rate is zero.\n";
543 
544         throw logic_error(buffer.str());
545     }
546 
547 #endif
548 
549     const type loss = back_propagation.loss;
550 
551     const type regularization_weight = loss_index_pointer->get_regularization_weight();
552 
553     Triplet triplet;
554 
555     // Left point
556 
557     triplet.A.first = 0;
558     triplet.A.second = loss;
559 
560     // Right point
561 
562     Index count = 0;
563 
564     do
565     {
566         count++;
567 
568         triplet.B.first = optimization_data.initial_learning_rate*count;
569 
570         optimization_data.potential_parameters.device(*thread_pool_device)
571                 = optimization_data.parameters + optimization_data.training_direction*triplet.B.first;
572 
573         neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
574 
575         loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);
576 
577         const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);
578 
579         triplet.B.second = back_propagation.error + regularization_weight*regularization;
580 
581     } while(abs(triplet.A.second - triplet.B.second) < numeric_limits<type>::min());
582 
583 
584     if(triplet.A.second > triplet.B.second)
585     {
586         triplet.U = triplet.B;
587 
588         triplet.B.first *= golden_ratio;
589 
590         optimization_data.potential_parameters.device(*thread_pool_device)
591                 = optimization_data.parameters + optimization_data.training_direction*triplet.B.first;
592 
593         neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
594 
595         loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);
596 
597         const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);
598 
599         triplet.B.second = back_propagation.error + regularization_weight*regularization;
600 
601         while(triplet.U.second > triplet.B.second)
602         {
603             triplet.A = triplet.U;
604             triplet.U = triplet.B;
605 
606             triplet.B.first *= golden_ratio;
607 
608             optimization_data.potential_parameters.device(*thread_pool_device) = optimization_data.parameters + optimization_data.training_direction*triplet.B.first;
609 
610             neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
611 
612             loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);
613 
614             const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);
615 
616             triplet.B.second = back_propagation.error + regularization_weight*regularization;
617         }
618     }
619     else if(triplet.A.second < triplet.B.second)
620     {
621         triplet.U.first = triplet.A.first + (triplet.B.first - triplet.A.first)*static_cast<type>(0.382);
622 
623         optimization_data.potential_parameters.device(*thread_pool_device) = optimization_data.parameters + optimization_data.training_direction*triplet.U.first;
624 
625         neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
626 
627         loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);
628 
629         const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);
630 
631         triplet.U.second = back_propagation.error + regularization_weight*regularization;
632 
633         while(triplet.A.second < triplet.U.second)
634         {
635             triplet.B = triplet.U;
636 
637             triplet.U.first = triplet.A.first + (triplet.B.first-triplet.A.first)*static_cast<type>(0.382);
638 
639             optimization_data.potential_parameters.device(*thread_pool_device)
640                     = optimization_data.parameters + optimization_data.training_direction*triplet.U.first;
641 
642             neural_network_pointer->forward_propagate(batch, optimization_data.potential_parameters, forward_propagation);
643 
644             loss_index_pointer->calculate_error(batch, forward_propagation, back_propagation);
645 
646             const type regularization = loss_index_pointer->calculate_regularization(optimization_data.potential_parameters);
647 
648             triplet.U.second = back_propagation.error + regularization_weight*regularization;
649 
650             if(triplet.U.first - triplet.A.first <= learning_rate_tolerance)
651             {
652                 triplet.U = triplet.A;
653                 triplet.B = triplet.A;
654 
655                 return triplet;
656             }
657         }
658     }
659 
660     return triplet;
661 }
662 
663 /// Calculates the golden section point within a minimum interval defined by three points.
664 /// @param triplet Triplet containing a minimum.
665 
calculate_golden_section_learning_rate(const Triplet & triplet) const666 type LearningRateAlgorithm::calculate_golden_section_learning_rate(const Triplet& triplet) const
667 {
668     type learning_rate;
669 
670     const type middle = triplet.A.first + static_cast<type>(0.5)*(triplet.B.first - triplet.A.first);
671 
672     if(triplet.U.first < middle)
673     {
674         learning_rate = triplet.A.first + static_cast<type>(0.618)*(triplet.B.first - triplet.A.first);
675     }
676     else
677     {
678         learning_rate = triplet.A.first + static_cast<type>(0.382)*(triplet.B.first - triplet.A.first);
679     }
680 
681 #ifdef __OPENNN_DEBUG__
682 
683     if(learning_rate < triplet.A.first)
684     {
685         ostringstream buffer;
686 
687         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
688                << "type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
689                << "Learning rate(" << learning_rate << ") is less than left point("
690                << triplet.A.first << ").\n";
691 
692         throw logic_error(buffer.str());
693     }
694 
695     if(learning_rate > triplet.B.first)
696     {
697         ostringstream buffer;
698 
699         buffer << "OpenNN Error: LearningRateAlgorithm class.\n"
700                << "type calculate_golden_section_learning_rate(const Triplet&) const method.\n"
701                << "Learning rate(" << learning_rate << ") is greater than right point("
702                << triplet.B.first << ").\n";
703 
704         throw logic_error(buffer.str());
705     }
706 
707 #endif
708 
709     return learning_rate;
710 }
711 
712 
/// Returns the minimal learning rate of a parabola defined by three directional points.
/// @param triplet Triplet containing a minimum.
715 
calculate_Brent_method_learning_rate(const Triplet & triplet) const716 type LearningRateAlgorithm::calculate_Brent_method_learning_rate(const Triplet& triplet) const
717 {
718     const type a = triplet.A.first;
719     const type u = triplet.U.first;
720     const type b = triplet.B.first;
721 
722     const type fa = triplet.A.second;
723     const type fu = triplet.U.second;
724     const type fb = triplet.B.second;
725 
726 
727     type numerator = (u-a)*(u-a)*(fu-fb) - (u-b)*(u-b)*(fu-fa);
728 
729     type denominator = (u-a)*(fu-fb) - (u-b)*(fu-fa);
730 
731     return u - 0.5*numerator/denominator;
732 
733 /*
734     const type c = -(triplet.A.second*(triplet.U.first-triplet.B.first)
735                      + triplet.U.second*(triplet.B.first-triplet.A.first)
736                      + triplet.B.second*(triplet.A.first-triplet.U.first))
737             /((triplet.A.first-triplet.U.first)*(triplet.U.first-triplet.B.first)*(triplet.B.first-triplet.A.first));
738 
739     if(abs(c) < numeric_limits<type>::min())
740     {
741         ostringstream buffer;
742 
743         buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
744                << "type calculate_Brent_method_learning_rate(const Triplet&) const method.\n"
745                << "Parabola cannot be constructed.\n";
746 
747         throw logic_error(buffer.str());
748     }
749     else if(c < 0)
750     {
751         ostringstream buffer;
752 
753         buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
754                << "type calculate_Brent_method_learning_rate(const Triplet&) const method.\n"
755                << "Parabola does not have a minimum but a maximum.\n";
756 
757         throw logic_error(buffer.str());
758     }
759 
760     const type b = (triplet.A.second*(triplet.U.first*triplet.U.first-triplet.B.first*triplet.B.first)
761                   + triplet.U.second*(triplet.B.first*triplet.B.first-triplet.A.first*triplet.A.first)
762                   + triplet.B.second*(triplet.A.first*triplet.A.first-triplet.U.first*triplet.U.first))
763                   /((triplet.A.first-triplet.U.first)*(triplet.U.first-triplet.B.first)*(triplet.B.first-triplet.A.first));
764 
765     const type Brent_method_learning_rate = -b/(static_cast<type>(2.0)*c);
766 
767     if(Brent_method_learning_rate < triplet.A.first || Brent_method_learning_rate > triplet.B.first)
768     {
769         ostringstream buffer;
770 
771         buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
772                << "type calculate_Brent_method_learning_rate() const method.\n"
773                << "Brent method learning rate is not inside interval.\n"
774                << "Interval:(" << triplet.A.first << "," << triplet.B.first << ")\n"
775                << "Brent method learning rate: " << Brent_method_learning_rate << endl;
776 
777 
778         throw logic_error(buffer.str());
779     }
780 
781     return Brent_method_learning_rate;
782 */
783 }
784 
785 
786 /// Serializes the learning rate algorithm object into a XML document of the TinyXML library
787 /// without keep the DOM tree in memory.
788 /// See the OpenNN manual for more information about the format of this document.
789 
write_XML(tinyxml2::XMLPrinter & file_stream) const790 void LearningRateAlgorithm::write_XML(tinyxml2::XMLPrinter& file_stream) const
791 {
792     ostringstream buffer;
793 
794     // Learning rate algorithm
795 
796     file_stream.OpenElement("LearningRateAlgorithm");
797 
798     // Learning rate method
799 
800     file_stream.OpenElement("LearningRateMethod");
801 
802     file_stream.PushText(write_learning_rate_method().c_str());
803 
804     file_stream.CloseElement();
805 
806     // Learning rate tolerance
807 
808     file_stream.OpenElement("LearningRateTolerance");
809 
810     buffer.str("");
811     buffer << learning_rate_tolerance;
812 
813     file_stream.PushText(buffer.str().c_str());
814 
815     file_stream.CloseElement();
816 
817     // Learning rate algorithm (end tag)
818 
819     file_stream.CloseElement();
820 }
821 
822 
/// Loads a learning rate algorithm object from a XML-type file.
/// Please mind about the file format, which is specified in the manual.
/// @param document TinyXML document with the learning rate algorithm members.
826 
from_XML(const tinyxml2::XMLDocument & document)827 void LearningRateAlgorithm::from_XML(const tinyxml2::XMLDocument& document)
828 {
829     const tinyxml2::XMLElement* root_element = document.FirstChildElement("LearningRateAlgorithm");
830 
831     if(!root_element)
832     {
833         ostringstream buffer;
834 
835         buffer << "OpenNN Exception: LearningRateAlgorithm class.\n"
836                << "void from_XML(const tinyxml2::XMLDocument&) method.\n"
837                << "Learning rate algorithm element is nullptr.\n";
838 
839         throw logic_error(buffer.str());
840     }
841 
842     // Learning rate method
843     {
844         const tinyxml2::XMLElement* element = root_element->FirstChildElement("LearningRateMethod");
845 
846         if(element)
847         {
848             string new_learning_rate_method = element->GetText();
849 
850             try
851             {
852                 set_learning_rate_method(new_learning_rate_method);
853             }
854             catch(const logic_error& e)
855             {
856                 cerr << e.what() << endl;
857             }
858         }
859     }
860 
861     // Learning rate tolerance
862     {
863         const tinyxml2::XMLElement* element = root_element->FirstChildElement("LearningRateTolerance");
864 
865         if(element)
866         {
867             const type new_learning_rate_tolerance = static_cast<type>(atof(element->GetText()));
868 
869             try
870             {
871                 set_learning_rate_tolerance(new_learning_rate_tolerance);
872             }
873             catch(const logic_error& e)
874             {
875                 cerr << e.what() << endl;
876             }
877         }
878     }
879 
880     // Display warnings
881     {
882         const tinyxml2::XMLElement* element = root_element->FirstChildElement("Display");
883 
884         if(element)
885         {
886             const string new_display = element->GetText();
887 
888             try
889             {
890                 set_display(new_display != "0");
891             }
892             catch(const logic_error& e)
893             {
894                 cerr << e.what() << endl;
895             }
896         }
897     }
898 }
899 
900 }
901 
902 
903 // OpenNN: Open Neural Networks Library.
904 // Copyright(C) 2005-2020 Artificial Intelligence Techniques, SL.
905 //
906 // This library is free software; you can redistribute it and/or
907 // modify it under the terms of the GNU Lesser General Public
908 // License as published by the Free Software Foundation; either
909 // version 2.1 of the License, or any later version.
910 //
911 // This library is distributed in the hope that it will be useful,
912 // but WITHOUT ANY WARRANTY; without even the implied warranty of
913 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
914 // Lesser General Public License for more details.
915 
916 // You should have received a copy of the GNU Lesser General Public
917 // License along with this library; if not, write to the Free Software
918 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
919