1 #ifndef CALLBACK_H_ 2 #define CALLBACK_H_ 3 4 #include <Eigen/Core> 5 #include "Config.h" 6 7 namespace MiniDNN 8 { 9 10 11 class Network; 12 13 /// 14 /// \defgroup Callbacks Callback Functions 15 /// 16 17 /// 18 /// \ingroup Callbacks 19 /// 20 /// The interface and default implementation of the callback function during 21 /// model fitting. The purpose of this class is to allow users printing some 22 /// messages in each epoch or mini-batch training, for example the time spent, 23 /// the loss function values, etc. 24 /// 25 /// This default implementation is a silent version of the callback function 26 /// that basically does nothing. See the VerboseCallback class for a verbose 27 /// version that prints the loss function value in each mini-batch. 28 /// 29 class Callback 30 { 31 protected: 32 typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix; 33 typedef Eigen::RowVectorXi IntegerVector; 34 35 public: 36 // Public members that will be set by the network during the training process 37 int m_nbatch; // Number of total batches 38 int m_batch_id; // The index for the current mini-batch (0, 1, ..., m_nbatch-1) 39 int m_nepoch; // Total number of epochs (one run on the whole data set) in the training process 40 int m_epoch_id; // The index for the current epoch (0, 1, ..., m_nepoch-1) 41 Callback()42 Callback() : 43 m_nbatch(0), m_batch_id(0), m_nepoch(0), m_epoch_id(0) 44 {} 45 ~Callback()46 virtual ~Callback() {} 47 48 // Before training a mini-batch pre_training_batch(const Network * net,const Matrix & x,const Matrix & y)49 virtual void pre_training_batch(const Network* net, const Matrix& x, 50 const Matrix& y) {} pre_training_batch(const Network * net,const Matrix & x,const IntegerVector & y)51 virtual void pre_training_batch(const Network* net, const Matrix& x, 52 const IntegerVector& y) {} 53 54 // After a mini-batch is trained post_training_batch(const Network * net,const Matrix & x,const Matrix & y)55 virtual void post_training_batch(const Network* net, const Matrix& x, 56 const Matrix& y) {} post_training_batch(const Network * net,const Matrix & x,const IntegerVector & y)57 virtual void post_training_batch(const Network* net, const Matrix& x, 58 const IntegerVector& y) {} 59 }; 60 61 62 } // namespace MiniDNN 63 64 65 #endif /* CALLBACK_H_ */ 66