1 // OpenNN: Open Neural Networks Library
2 // www.opennn.net
3 //
4 // M O D E L S E L E C T I O N C L A S S
5 //
6 // Artificial Intelligence Techniques SL
7 // artelnics@artelnics.com
8
#include "model_selection.h"

#include <cstdio>
#include <memory>
10
11 namespace OpenNN
12 {
13
14 /// Default constructor.
15
ModelSelection()16 ModelSelection::ModelSelection()
17 {
18 set_default();
19 }
20
21
22 /// Training strategy constructor.
23 /// @param new_training_strategy_pointer Pointer to a training strategy object.
24
ModelSelection(TrainingStrategy * new_training_strategy_pointer)25 ModelSelection::ModelSelection(TrainingStrategy* new_training_strategy_pointer)
26 {
27 training_strategy_pointer = new_training_strategy_pointer;
28
29 set_default();
30 }
31
32
33 /// Destructor.
34
~ModelSelection()35 ModelSelection::~ModelSelection()
36 {
37 }
38
39
40 /// Returns a pointer to the training strategy object.
41
get_training_strategy_pointer() const42 TrainingStrategy* ModelSelection::get_training_strategy_pointer() const
43 {
44 #ifdef __OPENNN_DEBUG__
45
46 if(!training_strategy_pointer)
47 {
48 ostringstream buffer;
49
50 buffer << "OpenNN Exception: ModelSelection class.\n"
51 << "TrainingStrategy* get_training_strategy_pointer() const method.\n"
52 << "Training strategy pointer is nullptr.\n";
53
54 throw logic_error(buffer.str());
55 }
56
57 #endif
58
59 return training_strategy_pointer;
60 }
61
62
63 /// Returns true if this model selection has a training strategy associated,
64 /// and false otherwise.
65
has_training_strategy() const66 bool ModelSelection::has_training_strategy() const
67 {
68 if(training_strategy_pointer)
69 {
70 return true;
71 }
72 else
73 {
74 return false;
75 }
76 }
77
78
79 /// Returns the type of algorithm for the order selection.
80
get_neurons_selection_method() const81 const ModelSelection::NeuronsSelectionMethod& ModelSelection::get_neurons_selection_method() const
82 {
83 return neurons_selection_method;
84 }
85
86
87 /// Returns the type of algorithm for the inputs selection.
88
get_inputs_selection_method() const89 const ModelSelection::InputsSelectionMethod& ModelSelection::get_inputs_selection_method() const
90 {
91 return inputs_selection_method;
92 }
93
94
95 /// Returns a pointer to the growing neurons selection algorithm.
96
get_growing_neurons_pointer()97 GrowingNeurons* ModelSelection::get_growing_neurons_pointer()
98 {
99 return &growing_neurons;
100 }
101
102
103 /// Returns a pointer to the growing inputs selection algorithm.
104
get_growing_inputs_pointer()105 GrowingInputs* ModelSelection::get_growing_inputs_pointer()
106 {
107 return &growing_inputs;
108 }
109
110
111 /// Returns a pointer to the pruning inputs selection algorithm.
112
get_pruning_inputs_pointer()113 PruningInputs* ModelSelection::get_pruning_inputs_pointer()
114 {
115 return &pruning_inputs;
116 }
117
118
119 /// Returns a pointer to the genetic inputs selection algorithm.
120
get_genetic_algorithm_pointer()121 GeneticAlgorithm* ModelSelection::get_genetic_algorithm_pointer()
122 {
123 return &genetic_algorithm;
124 }
125
126
127 /// Sets the members of the model selection object to their default values.
128
set_default()129 void ModelSelection::set_default()
130 {
131 set_neurons_selection_method(GROWING_NEURONS);
132
133 set_inputs_selection_method(GROWING_INPUTS);
134
135 display = true;
136 }
137
138
139 /// Sets a new display value.
140 /// If it is set to true messages from this class are to be displayed on the screen;
141 /// if it is set to false messages from this class are not to be displayed on the screen.
142 /// @param new_display Display value.
143
set_display(const bool & new_display)144 void ModelSelection::set_display(const bool& new_display)
145 {
146 display = new_display;
147
148 switch(inputs_selection_method)
149 {
150 case NO_INPUTS_SELECTION:
151 {
152 // do nothing
153
154 break;
155 }
156 case GROWING_INPUTS:
157 {
158 growing_inputs.set_display(new_display);
159
160 break;
161 }
162 case PRUNING_INPUTS:
163 {
164 pruning_inputs.set_display(new_display);
165
166 break;
167 }
168 case GENETIC_ALGORITHM:
169 {
170 genetic_algorithm.set_display(new_display);
171
172 break;
173 }
174 }
175
176 switch(neurons_selection_method)
177 {
178 case NO_NEURONS_SELECTION:
179 {
180 // do nothing
181
182 break;
183 }
184 case GROWING_NEURONS:
185 {
186 growing_neurons.set_display(new_display);
187
188 break;
189 }
190 }
191 }
192
193
194 /// Sets a new method for selecting the order which have more impact on the targets.
195 /// @param new_neurons_selection_method Method for selecting the order(NO_NEURONS_SELECTION, growing_neurons, GOLDEN_SECTION, SIMULATED_ANNEALING).
196
set_neurons_selection_method(const ModelSelection::NeuronsSelectionMethod & new_neurons_selection_method)197 void ModelSelection::set_neurons_selection_method(const ModelSelection::NeuronsSelectionMethod& new_neurons_selection_method)
198 {
199 neurons_selection_method = new_neurons_selection_method;
200 }
201
202
203 /// Sets a new order selection algorithm from a string.
204 /// @param new_neurons_selection_method String with the order selection type.
205
set_neurons_selection_method(const string & new_neurons_selection_method)206 void ModelSelection::set_neurons_selection_method(const string& new_neurons_selection_method)
207 {
208 if(new_neurons_selection_method == "NO_NEURONS_SELECTION")
209 {
210 set_neurons_selection_method(NO_NEURONS_SELECTION);
211 }
212 else if(new_neurons_selection_method == "GROWING_NEURONS")
213 {
214 set_neurons_selection_method(GROWING_NEURONS);
215 }
216 else
217 {
218 ostringstream buffer;
219
220 buffer << "OpenNN Exception: ModelSelection class.\n"
221 << "void set_neurons_selection_method(const string&) method.\n"
222 << "Unknown order selection type: " << new_neurons_selection_method << ".\n";
223
224 throw logic_error(buffer.str());
225 }
226 }
227
228
229 /// Sets a new method for selecting the inputs which have more impact on the targets.
230 /// @param new_inputs_selection_method Method for selecting the inputs(NO_INPUTS_SELECTION, GROWING_INPUTS, PRUNING_INPUTS, GENETIC_ALGORITHM).
231
set_inputs_selection_method(const ModelSelection::InputsSelectionMethod & new_inputs_selection_method)232 void ModelSelection::set_inputs_selection_method(const ModelSelection::InputsSelectionMethod& new_inputs_selection_method)
233 {
234 inputs_selection_method = new_inputs_selection_method;
235 }
236
237
238 /// Sets a new inputs selection algorithm from a string.
239 /// @param new_inputs_selection_method String with the inputs selection type.
240
set_inputs_selection_method(const string & new_inputs_selection_method)241 void ModelSelection::set_inputs_selection_method(const string& new_inputs_selection_method)
242 {
243 if(new_inputs_selection_method == "NO_INPUTS_SELECTION")
244 {
245 set_inputs_selection_method(NO_INPUTS_SELECTION);
246 }
247 else if(new_inputs_selection_method == "GROWING_INPUTS")
248 {
249 set_inputs_selection_method(GROWING_INPUTS);
250 }
251 else if(new_inputs_selection_method == "PRUNING_INPUTS")
252 {
253 set_inputs_selection_method(PRUNING_INPUTS);
254 }
255 else if(new_inputs_selection_method == "GENETIC_ALGORITHM")
256 {
257 set_inputs_selection_method(GENETIC_ALGORITHM);
258 }
259 else
260 {
261 ostringstream buffer;
262
263 buffer << "OpenNN Exception: ModelSelection class.\n"
264 << "void set_inputs_selection_method(const string&) method.\n"
265 << "Unknown inputs selection type: " << new_inputs_selection_method << ".\n";
266
267 throw logic_error(buffer.str());
268 }
269 }
270
271
272 /// Sets a new approximation method.
273 /// If it is set to true the problem will be taken as a approximation;
274 /// if it is set to false the problem will be taken as a classification.
275 /// @param new_approximation Approximation value.
276
set_approximation(const bool & new_approximation)277 void ModelSelection::set_approximation(const bool& new_approximation)
278 {
279 switch(inputs_selection_method)
280 {
281 case NO_INPUTS_SELECTION:
282 {
283 // do nothing
284
285 break;
286 }
287 case GROWING_INPUTS:
288 {
289 growing_inputs.set_approximation(new_approximation);
290
291 break;
292 }
293 case PRUNING_INPUTS:
294 {
295 pruning_inputs.set_approximation(new_approximation);
296
297 break;
298 }
299 case GENETIC_ALGORITHM:
300 {
301 genetic_algorithm.set_approximation(new_approximation);
302
303 break;
304 }
305 }
306 }
307
308
309 /// Sets a new training strategy pointer.
310 /// @param new_training_strategy_pointer Pointer to a training strategy object.
311
set_training_strategy_pointer(TrainingStrategy * new_training_strategy_pointer)312 void ModelSelection::set_training_strategy_pointer(TrainingStrategy* new_training_strategy_pointer)
313 {
314 training_strategy_pointer = new_training_strategy_pointer;
315
316 growing_neurons.set_training_strategy_pointer(new_training_strategy_pointer);
317
318 growing_inputs.set_training_strategy_pointer(new_training_strategy_pointer);
319 pruning_inputs.set_training_strategy_pointer(new_training_strategy_pointer);
320 genetic_algorithm.set_training_strategy_pointer(new_training_strategy_pointer);
321 }
322
323
324 /// Checks that the different pointers needed for performing the model selection are not nullptr.
325
check() const326 void ModelSelection::check() const
327 {
328
329 // Optimization algorithm
330
331 ostringstream buffer;
332
333 if(!training_strategy_pointer)
334 {
335 buffer << "OpenNN Exception: ModelSelection class.\n"
336 << "void check() const method.\n"
337 << "Pointer to training strategy is nullptr.\n";
338
339 throw logic_error(buffer.str());
340 }
341
342 // Loss index
343
344 const LossIndex* loss_index_pointer = training_strategy_pointer->get_loss_index_pointer();
345
346 if(!loss_index_pointer)
347 {
348 buffer << "OpenNN Exception: ModelSelection class.\n"
349 << "void check() const method.\n"
350 << "Pointer to loss index is nullptr.\n";
351
352 throw logic_error(buffer.str());
353 }
354
355 // Neural network
356
357 const NeuralNetwork* neural_network_pointer = loss_index_pointer->get_neural_network_pointer();
358
359 if(!neural_network_pointer)
360 {
361 buffer << "OpenNN Exception: ModelSelection class.\n"
362 << "void check() const method.\n"
363 << "Pointer to neural network is nullptr.\n";
364
365 throw logic_error(buffer.str());
366 }
367
368 if(neural_network_pointer->is_empty())
369 {
370 buffer << "OpenNN Exception: ModelSelection class.\n"
371 << "void check() const method.\n"
372 << "Multilayer Perceptron is empty.\n";
373
374 throw logic_error(buffer.str());
375 }
376
377 // Data set
378
379 const DataSet* data_set_pointer = loss_index_pointer->get_data_set_pointer();
380
381 if(!data_set_pointer)
382 {
383 buffer << "OpenNN Exception: ModelSelection class.\n"
384 << "void check() const method.\n"
385 << "Pointer to data set is nullptr.\n";
386
387 throw logic_error(buffer.str());
388 }
389
390 //
391
392 const Index selection_samples_number = data_set_pointer->get_selection_samples_number();
393
394 if(selection_samples_number == 0)
395 {
396 buffer << "OpenNN Exception: ModelSelection class.\n"
397 << "void check() const method.\n"
398 << "Number of selection samples is zero.\n";
399
400 throw logic_error(buffer.str());
401 }
402 }
403
404
405 /// Perform the order selection, returns a structure with the results of the order selection.
406 /// It also set the neural network of the training strategy pointer with the optimum parameters.
407
perform_neurons_selection()408 ModelSelection::Results ModelSelection::perform_neurons_selection()
409 {
410 Results results;
411
412 TrainingStrategy* ts = get_training_strategy_pointer();
413
414 switch(neurons_selection_method)
415 {
416 case NO_NEURONS_SELECTION:
417 {
418 break;
419 }
420 case GROWING_NEURONS:
421 {
422 growing_neurons.set_display(display);
423
424 growing_neurons.set_training_strategy_pointer(ts);
425
426 results.growing_neurons_results_pointer = growing_neurons.perform_neurons_selection();
427
428 break;
429 }
430 }
431
432 return results;
433 }
434
435
436 /// Perform the inputs selection, returns a structure with the results of the inputs selection.
437 /// It also set the neural network of the training strategy pointer with the optimum parameters.
438
perform_inputs_selection()439 ModelSelection::Results ModelSelection::perform_inputs_selection()
440 {
441 Results results;
442
443 TrainingStrategy* ts = get_training_strategy_pointer();
444
445 switch(inputs_selection_method)
446 {
447 case NO_INPUTS_SELECTION:
448 {
449 break;
450 }
451 case GROWING_INPUTS:
452 {
453 growing_inputs.set_display(display);
454
455 growing_inputs.set_training_strategy_pointer(ts);
456
457 results.growing_inputs_results_pointer = growing_inputs.perform_inputs_selection();
458
459 break;
460 }
461 case PRUNING_INPUTS:
462 {
463 pruning_inputs.set_display(display);
464
465 pruning_inputs.set_training_strategy_pointer(ts);
466
467 results.pruning_inputs_results_pointer = pruning_inputs.perform_inputs_selection();
468
469 break;
470 }
471 case GENETIC_ALGORITHM:
472 {
473 genetic_algorithm.set_display(display);
474
475 genetic_algorithm.set_training_strategy_pointer(ts);
476
477 results.genetic_algorithm_results_pointer = genetic_algorithm.perform_inputs_selection();
478
479 break;
480 }
481 }
482
483 return results;
484 }
485
486
487 /// Perform inputs selection and order selection.
488 /// @todo
489
perform_model_selection()490 ModelSelection::Results ModelSelection::perform_model_selection()
491 {
492 perform_inputs_selection();
493
494 return perform_neurons_selection();
495 }
496
497
498 /// Serializes the model selection object into a XML document of the TinyXML library without keep the DOM tree in memory.
499 /// See the OpenNN manual for more information about the format of this document.
500
write_XML(tinyxml2::XMLPrinter & file_stream) const501 void ModelSelection::write_XML(tinyxml2::XMLPrinter& file_stream) const
502 {
503 // Model selection
504
505 file_stream.OpenElement("ModelSelection");
506
507 // Neurons selection
508
509 file_stream.OpenElement("NeuronsSelection");
510
511 file_stream.OpenElement("NeuronsSelectionMethod");
512 file_stream.PushText(write_neurons_selection_method().c_str());
513 file_stream.CloseElement();
514
515 growing_neurons.write_XML(file_stream);
516
517 file_stream.CloseElement();
518
519 // Inputs selection
520
521 file_stream.OpenElement("InputsSelection");
522
523 file_stream.OpenElement("InputsSelectionMethod");
524 file_stream.PushText(write_inputs_selection_method().c_str());
525 file_stream.CloseElement();
526
527 growing_inputs.write_XML(file_stream);
528 pruning_inputs.write_XML(file_stream);
529 genetic_algorithm.write_XML(file_stream);
530
531 file_stream.CloseElement();
532
533 // Model selection (end tag)
534
535 file_stream.CloseElement();
536 }
537
538
539 /// Loads the members of this model selection object from a XML document.
540 /// @param document XML document of the TinyXML library.
541
from_XML(const tinyxml2::XMLDocument & document)542 void ModelSelection::from_XML(const tinyxml2::XMLDocument& document)
543 {
544 const tinyxml2::XMLElement* root_element = document.FirstChildElement("ModelSelection");
545
546 if(!root_element)
547 {
548 ostringstream buffer;
549
550 buffer << "OpenNN Exception: ModelSelection class.\n"
551 << "void from_XML(const tinyxml2::XMLDocument&) method.\n"
552 << "Model Selection element is nullptr.\n";
553
554 throw logic_error(buffer.str());
555 }
556
557 // Neurons Selection
558
559 const tinyxml2::XMLElement* neurons_selection_element = root_element->FirstChildElement("NeuronsSelection");
560
561 if(neurons_selection_element)
562 {
563 // Neurons selection method
564
565 const tinyxml2::XMLElement* neurons_selection_method_element = neurons_selection_element->FirstChildElement("NeuronsSelectionMethod");
566
567 set_neurons_selection_method(neurons_selection_method_element->GetText());
568
569 // Growing neurons
570
571 const tinyxml2::XMLElement* growing_neurons_element = neurons_selection_element->FirstChildElement("GrowingNeurons");
572
573 if(growing_neurons_element)
574 {
575 tinyxml2::XMLDocument growing_neurons_document;
576
577 tinyxml2::XMLElement* growing_neurons_element_copy = growing_neurons_document.NewElement("GrowingNeurons");
578
579 for(const tinyxml2::XMLNode* nodeFor=growing_neurons_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
580 {
581 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&growing_neurons_document );
582 growing_neurons_element_copy->InsertEndChild(copy );
583 }
584
585 growing_neurons_document.InsertEndChild(growing_neurons_element_copy);
586
587 growing_neurons.from_XML(growing_neurons_document);
588 }
589
590 }
591
592 // Inputs Selection
593
594 {
595 const tinyxml2::XMLElement* inputs_selection_element = root_element->FirstChildElement("InputsSelection");
596
597 if(inputs_selection_element)
598 {
599 const tinyxml2::XMLElement* inputs_selection_method_element = inputs_selection_element->FirstChildElement("InputsSelectionMethod");
600
601 set_inputs_selection_method(inputs_selection_method_element->GetText());
602
603 // Growing inputs
604
605 const tinyxml2::XMLElement* growing_inputs_element = inputs_selection_element->FirstChildElement("GrowingInputs");
606
607 if(growing_inputs_element)
608 {
609 tinyxml2::XMLDocument growing_inputs_document;
610
611 tinyxml2::XMLElement* growing_inputs_element_copy = growing_inputs_document.NewElement("GrowingInputs");
612
613 for(const tinyxml2::XMLNode* nodeFor=growing_inputs_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
614 {
615 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&growing_inputs_document );
616 growing_inputs_element_copy->InsertEndChild(copy );
617 }
618
619 growing_inputs_document.InsertEndChild(growing_inputs_element_copy);
620
621 growing_inputs.from_XML(growing_inputs_document);
622 }
623
624
625 // Pruning inputs
626
627 const tinyxml2::XMLElement* pruning_inputs_element = inputs_selection_element->FirstChildElement("PruningInputs");
628
629 if(pruning_inputs_element)
630 {
631 tinyxml2::XMLDocument pruning_inputs_document;
632
633 tinyxml2::XMLElement* pruning_inputs_element_copy = pruning_inputs_document.NewElement("PruningInputs");
634
635 for(const tinyxml2::XMLNode* nodeFor=pruning_inputs_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
636 {
637 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&pruning_inputs_document );
638 pruning_inputs_element_copy->InsertEndChild(copy );
639 }
640
641 pruning_inputs_document.InsertEndChild(pruning_inputs_element_copy);
642
643 pruning_inputs.from_XML(pruning_inputs_document);
644 }
645
646
647 // Genetic algorithm
648
649 const tinyxml2::XMLElement* genetic_algorithm_element = inputs_selection_element->FirstChildElement("GeneticAlgorithm");
650
651 if(genetic_algorithm_element)
652 {
653 tinyxml2::XMLDocument genetic_algorithm_document;
654
655 tinyxml2::XMLElement* genetic_algorithm_element_copy = genetic_algorithm_document.NewElement("GeneticAlgorithm");
656
657 for(const tinyxml2::XMLNode* nodeFor=genetic_algorithm_element->FirstChild(); nodeFor; nodeFor=nodeFor->NextSibling())
658 {
659 tinyxml2::XMLNode* copy = nodeFor->DeepClone(&genetic_algorithm_document );
660 genetic_algorithm_element_copy->InsertEndChild(copy );
661 }
662
663 genetic_algorithm_document.InsertEndChild(genetic_algorithm_element_copy);
664
665 genetic_algorithm.from_XML(genetic_algorithm_document);
666 }
667
668 }
669 }
670 }
671
672
write_neurons_selection_method() const673 string ModelSelection::write_neurons_selection_method() const
674 {
675 switch (neurons_selection_method)
676 {
677 case NO_NEURONS_SELECTION:
678 return "NO_NEURONS_SELECTION";
679
680 case GROWING_NEURONS:
681 return "GROWING_NEURONS";
682 }
683 }
684
685
write_inputs_selection_method() const686 string ModelSelection::write_inputs_selection_method() const
687 {
688 switch (inputs_selection_method)
689 {
690 case NO_INPUTS_SELECTION:
691 return "NO_INPUTS_SELECTION";
692
693 case GROWING_INPUTS:
694 return "GROWING_INPUTS";
695
696 case PRUNING_INPUTS:
697 return "PRUNING_INPUTS";
698
699 case GENETIC_ALGORITHM:
700 return "GENETIC_ALGORITHM";
701 }
702 }
703
704
705 /// Prints to the screen the XML representation of this model selection object.
706
print() const707 void ModelSelection::print() const
708 {
709 // cout << to_string();
710 }
711
712
713 /// Saves the model selection members to a XML file.
714 /// @param file_name Name of model selection XML file.
715
save(const string & file_name) const716 void ModelSelection::save(const string& file_name) const
717 {
718 FILE *pFile;
719 // int err;
720
721 // err = fopen_s(&pFile, file_name.c_str(), "w");
722 pFile = fopen(file_name.c_str(), "w");
723
724 tinyxml2::XMLPrinter document(pFile);
725
726 write_XML(document);
727
728 fclose(pFile);
729 }
730
731
732 /// Loads the model selection members from a XML file.
733 /// @param file_name Name of model selection XML file.
734
load(const string & file_name)735 void ModelSelection::load(const string& file_name)
736 {
737 tinyxml2::XMLDocument document;
738
739 if(document.LoadFile(file_name.c_str()))
740 {
741 ostringstream buffer;
742
743 buffer << "OpenNN Exception: ModelSelection class.\n"
744 << "void load(const string&) method.\n"
745 << "Cannot load XML file " << file_name << ".\n";
746
747 throw logic_error(buffer.str());
748 }
749
750 from_XML(document);
751 }
752
753 /// Results constructor.
754
Results()755 ModelSelection::Results::Results()
756 {
757 growing_neurons_results_pointer = nullptr;
758
759 growing_inputs_results_pointer = nullptr;
760
761 pruning_inputs_results_pointer = nullptr;
762
763 genetic_algorithm_results_pointer = nullptr;
764 }
765
766 }
767
768 // OpenNN: Open Neural Networks Library.
769 // Copyright(C) 2005-2020 Artificial Intelligence Techniques, SL.
770 //
771 // This library is free software; you can redistribute it and/or
772 // modify it under the terms of the GNU Lesser General Public
773 // License as published by the Free Software Foundation; either
774 // version 2.1 of the License, or any later version.
775 //
776 // This library is distributed in the hope that it will be useful,
777 // but WITHOUT ANY WARRANTY; without even the implied warranty of
778 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
779 // Lesser General Public License for more details.
780
781 // You should have received a copy of the GNU Lesser General Public
782 // License along with this library; if not, write to the Free Software
783 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
784