1 //   OpenNN: Open Neural Networks Library
2 //   www.opennn.net
3 //
4 //   M O D E L   S E L E C T I O N   C L A S S   H E A D E R
5 //
6 //   Artificial Intelligence Techniques SL
7 //   artelnics@artelnics.com
8 
9 #ifndef MODELSELECTION_H
10 #define MODELSELECTION_H
11 
12 // System includes
13 
14 #include <iostream>
15 #include <fstream>
16 #include <string>
17 #include <sstream>
18 #include <cmath>
19 #include <ctime>
20 
21 // OpenNN includes
22 
23 #include "config.h"
24 #include "training_strategy.h"
25 #include "growing_neurons.h"
26 #include "growing_inputs.h"
27 #include "pruning_inputs.h"
28 #include "genetic_algorithm.h"
29 
30 namespace OpenNN
31 {
32 
33 /// This class represents the concept of model selection[1] algorithm in OpenNN.
34 
35 ///
36 /// It is used for finding a network architecture with maximum generalization capabilities.
37 ///
38 /// [1] Neural Designer "Model Selection Algorithms in Predictive Analytics." \ref https://www.neuraldesigner.com/blog/model-selection
39 
40 class ModelSelection
41 {
42 
43 public:
44 
45     // Constructors
46 
47     explicit ModelSelection();
48 
49     explicit ModelSelection(TrainingStrategy*);
50 
51     // Destructor
52 
53     virtual ~ModelSelection();
54 
55     /// Enumeration of all the available order selection algorithms.
56 
57     enum NeuronsSelectionMethod{NO_NEURONS_SELECTION, GROWING_NEURONS};
58 
59     /// Enumeration of all the available inputs selection algorithms.
60 
61     enum InputsSelectionMethod{NO_INPUTS_SELECTION, GROWING_INPUTS, PRUNING_INPUTS, GENETIC_ALGORITHM};
62 
63     /// This structure contains the results from the model selection process.
64 
65     struct Results
66     {
67         /// Default constructor.
68 
69         explicit Results();
70 
71         /// Pointer to a structure with the results from the growing neurons selection algorithm.
72 
73         GrowingNeurons::GrowingNeuronsResults* growing_neurons_results_pointer = nullptr;
74 
75 
76         /// Pointer to a structure with the results from the growing inputs selection algorithm.
77 
78         GrowingInputs::GrowingInputsResults* growing_inputs_results_pointer = nullptr;
79 
80 
81         /// Pointer to a structure with the results from the pruning inputs selection algorithm.
82 
83         PruningInputs::PruningInputsResults* pruning_inputs_results_pointer = nullptr;
84 
85 
86         /// Pointer to a structure with the results from the genetic inputs selection algorithm.
87 
88         GeneticAlgorithm::GeneticAlgorithmResults* genetic_algorithm_results_pointer = nullptr;
89     };
90 
91 
92     // Get methods
93 
94     TrainingStrategy* get_training_strategy_pointer() const;
95     bool has_training_strategy() const;
96 
97     const NeuronsSelectionMethod& get_neurons_selection_method() const;
98     const InputsSelectionMethod& get_inputs_selection_method() const;
99 
100     GrowingNeurons* get_growing_neurons_pointer();
101 
102     GrowingInputs* get_growing_inputs_pointer();
103     PruningInputs* get_pruning_inputs_pointer();
104     GeneticAlgorithm* get_genetic_algorithm_pointer();
105 
106     // Set methods
107 
108     void set_default();
109 
110     void set_display(const bool&);
111 
112     void set_training_strategy_pointer(TrainingStrategy*);
113 
114     void set_neurons_selection_method(const NeuronsSelectionMethod&);
115     void set_neurons_selection_method(const string&);
116 
117     void set_inputs_selection_method(const InputsSelectionMethod&);
118     void set_inputs_selection_method(const string&);
119 
120     void set_approximation(const bool&);
121 
122     // Model selection methods
123 
124     void check() const;
125 
126     Results perform_neurons_selection();
127 
128     Results perform_inputs_selection();
129 
130     Results perform_model_selection();
131 
132     // Serialization methods
133 
134     void from_XML(const tinyxml2::XMLDocument&);
135 
136     void write_XML(tinyxml2::XMLPrinter&) const;
137 
138     string write_neurons_selection_method() const;
139     string write_inputs_selection_method() const;
140 
141     void print() const;
142     void save(const string&) const;
143     void load(const string&);
144 
145 private:
146 
147     /// Pointer to a training strategy object.
148 
149     TrainingStrategy* training_strategy_pointer = nullptr;
150 
151     /// Growing order object to be used for order selection.
152 
153     GrowingNeurons growing_neurons;
154 
155     /// Growing inputs object to be used for inputs selection.
156 
157     GrowingInputs growing_inputs;
158 
159     /// Pruning inputs object to be used for inputs selection.
160 
161     PruningInputs pruning_inputs;
162 
163     /// Genetic algorithm object to be used for inputs selection.
164 
165     GeneticAlgorithm genetic_algorithm;
166 
167     /// Type of order selection algorithm.
168 
169     NeuronsSelectionMethod neurons_selection_method;
170 
171     /// Type of inputs selection algorithm.
172 
173     InputsSelectionMethod inputs_selection_method;
174 
175     /// Display messages to screen.
176 
177     bool display = true;
178 };
179 
180 }
181 
182 #endif
183