1 // OpenNN: Open Neural Networks Library 2 // www.opennn.net 3 // 4 // C R O S S E N T R O P Y E R R O R C L A S S H E A D E R 5 // 6 // Artificial Intelligence Techniques SL 7 // artelnics@artelnics.com 8 9 #ifndef CROSSENTROPYERROR_H 10 #define CROSSENTROPYERROR_H 11 12 // System includes 13 14 #include <iostream> 15 #include <fstream> 16 #include <math.h> 17 18 // OpenNN includes 19 20 #include "loss_index.h" 21 #include "data_set.h" 22 #include "config.h" 23 24 namespace OpenNN 25 { 26 27 /// This class represents the cross entropy error term, used for predicting probabilities. 28 29 /// 30 /// This functional is used in classification problems. 31 32 class CrossEntropyError : public LossIndex 33 { 34 35 public: 36 37 // Constructors 38 39 explicit CrossEntropyError(); 40 41 explicit CrossEntropyError(NeuralNetwork*, DataSet*); 42 43 // Destructor 44 45 virtual ~CrossEntropyError(); 46 47 // Error methods 48 49 void calculate_error(const DataSet::Batch& batch, 50 const NeuralNetwork::ForwardPropagation& forward_propagation, 51 LossIndex::BackPropagation& back_propagation) const; 52 53 void calculate_binary_error(const DataSet::Batch& batch, 54 const NeuralNetwork::ForwardPropagation& forward_propagation, 55 LossIndex::BackPropagation& back_propagation) const; 56 57 void calculate_multiple_error(const DataSet::Batch& batch, 58 const NeuralNetwork::ForwardPropagation& forward_propagation, 59 LossIndex::BackPropagation& back_propagation) const; 60 61 // Gradient methods 62 63 void calculate_output_gradient(const DataSet::Batch& batch, 64 const NeuralNetwork::ForwardPropagation& forward_propagation, 65 BackPropagation& back_propagation) const; 66 67 void calculate_binary_output_gradient(const DataSet::Batch& batch, 68 const NeuralNetwork::ForwardPropagation& forward_propagation, 69 BackPropagation& back_propagation) const; 70 71 void calculate_multiple_output_gradient(const DataSet::Batch& batch, 72 const NeuralNetwork::ForwardPropagation& forward_propagation, 73 BackPropagation& back_propagation) const; 74 75 string get_error_type() const; 76 string get_error_type_text() const; 77 78 // Serialization methods 79 80 81 void from_XML(const tinyxml2::XMLDocument&); 82 83 void write_XML(tinyxml2::XMLPrinter&) const; 84 85 #ifdef OPENNN_CUDA 86 #include "../../opennn-cuda/opennn_cuda/cross_entropy_error_cuda.h" 87 #endif 88 89 #ifdef OPENNN_MKL 90 #include "../../opennn-mkl/opennn_mkl/cross_entropy_error_mkl.h" 91 #endif 92 }; 93 94 } 95 96 #endif 97 98 99 // OpenNN: Open Neural Networks Library. 100 // Copyright(C) 2005-2020 Artificial Intelligence Techniques, SL. 101 // 102 // This library is free software; you can redistribute it and/or 103 // modify it under the terms of the GNU Lesser General Public 104 // License as published by the Free Software Foundation; either 105 // version 2.1 of the License, or any later version. 106 // 107 // This library is distributed in the hope that it will be useful, 108 // but WITHOUT ANY WARRANTY; without even the implied warranty of 109 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 110 // Lesser General Public License for more details. 111 112 // You should have received a copy of the GNU Lesser General Public 113 // License along with this library; if not, write to the Free Software 114 // Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA 115