1 // OpenNN: Open Neural Networks Library 2 // www.opennn.net 3 // 4 // M O D E L S E L E C T I O N C L A S S H E A D E R 5 // 6 // Artificial Intelligence Techniques SL 7 // artelnics@artelnics.com 8 9 #ifndef MODELSELECTION_H 10 #define MODELSELECTION_H 11 12 // System includes 13 14 #include <iostream> 15 #include <fstream> 16 #include <string> 17 #include <sstream> 18 #include <cmath> 19 #include <ctime> 20 21 // OpenNN includes 22 23 #include "config.h" 24 #include "training_strategy.h" 25 #include "growing_neurons.h" 26 #include "growing_inputs.h" 27 #include "pruning_inputs.h" 28 #include "genetic_algorithm.h" 29 30 namespace OpenNN 31 { 32 33 /// This class represents the concept of model selection[1] algorithm in OpenNN. 34 35 /// 36 /// It is used for finding a network architecture with maximum generalization capabilities. 37 /// 38 /// [1] Neural Designer "Model Selection Algorithms in Predictive Analytics." \ref https://www.neuraldesigner.com/blog/model-selection 39 40 class ModelSelection 41 { 42 43 public: 44 45 // Constructors 46 47 explicit ModelSelection(); 48 49 explicit ModelSelection(TrainingStrategy*); 50 51 // Destructor 52 53 virtual ~ModelSelection(); 54 55 /// Enumeration of all the available order selection algorithms. 56 57 enum NeuronsSelectionMethod{NO_NEURONS_SELECTION, GROWING_NEURONS}; 58 59 /// Enumeration of all the available inputs selection algorithms. 60 61 enum InputsSelectionMethod{NO_INPUTS_SELECTION, GROWING_INPUTS, PRUNING_INPUTS, GENETIC_ALGORITHM}; 62 63 /// This structure contains the results from the model selection process. 64 65 struct Results 66 { 67 /// Default constructor. 68 69 explicit Results(); 70 71 /// Pointer to a structure with the results from the growing neurons selection algorithm. 72 73 GrowingNeurons::GrowingNeuronsResults* growing_neurons_results_pointer = nullptr; 74 75 76 /// Pointer to a structure with the results from the growing inputs selection algorithm. 77 78 GrowingInputs::GrowingInputsResults* growing_inputs_results_pointer = nullptr; 79 80 81 /// Pointer to a structure with the results from the pruning inputs selection algorithm. 82 83 PruningInputs::PruningInputsResults* pruning_inputs_results_pointer = nullptr; 84 85 86 /// Pointer to a structure with the results from the genetic inputs selection algorithm. 87 88 GeneticAlgorithm::GeneticAlgorithmResults* genetic_algorithm_results_pointer = nullptr; 89 }; 90 91 92 // Get methods 93 94 TrainingStrategy* get_training_strategy_pointer() const; 95 bool has_training_strategy() const; 96 97 const NeuronsSelectionMethod& get_neurons_selection_method() const; 98 const InputsSelectionMethod& get_inputs_selection_method() const; 99 100 GrowingNeurons* get_growing_neurons_pointer(); 101 102 GrowingInputs* get_growing_inputs_pointer(); 103 PruningInputs* get_pruning_inputs_pointer(); 104 GeneticAlgorithm* get_genetic_algorithm_pointer(); 105 106 // Set methods 107 108 void set_default(); 109 110 void set_display(const bool&); 111 112 void set_training_strategy_pointer(TrainingStrategy*); 113 114 void set_neurons_selection_method(const NeuronsSelectionMethod&); 115 void set_neurons_selection_method(const string&); 116 117 void set_inputs_selection_method(const InputsSelectionMethod&); 118 void set_inputs_selection_method(const string&); 119 120 void set_approximation(const bool&); 121 122 // Model selection methods 123 124 void check() const; 125 126 Results perform_neurons_selection(); 127 128 Results perform_inputs_selection(); 129 130 Results perform_model_selection(); 131 132 // Serialization methods 133 134 void from_XML(const tinyxml2::XMLDocument&); 135 136 void write_XML(tinyxml2::XMLPrinter&) const; 137 138 string write_neurons_selection_method() const; 139 string write_inputs_selection_method() const; 140 141 void print() const; 142 void save(const string&) const; 143 void load(const string&); 144 145 private: 146 147 /// Pointer to a training strategy object. 148 149 TrainingStrategy* training_strategy_pointer = nullptr; 150 151 /// Growing order object to be used for order selection. 152 153 GrowingNeurons growing_neurons; 154 155 /// Growing inputs object to be used for inputs selection. 156 157 GrowingInputs growing_inputs; 158 159 /// Pruning inputs object to be used for inputs selection. 160 161 PruningInputs pruning_inputs; 162 163 /// Genetic algorithm object to be used for inputs selection. 164 165 GeneticAlgorithm genetic_algorithm; 166 167 /// Type of order selection algorithm. 168 169 NeuronsSelectionMethod neurons_selection_method; 170 171 /// Type of inputs selection algorithm. 172 173 InputsSelectionMethod inputs_selection_method; 174 175 /// Display messages to screen. 176 177 bool display = true; 178 }; 179 180 } 181 182 #endif 183