1 #ifndef CALLBACK_H_
2 #define CALLBACK_H_
3 
4 #include <Eigen/Core>
5 #include "Config.h"
6 
7 namespace MiniDNN
8 {
9 
10 
11 class Network;
12 
13 ///
14 /// \defgroup Callbacks Callback Functions
15 ///
16 
17 ///
18 /// \ingroup Callbacks
19 ///
20 /// The interface and default implementation of the callback function during
21 /// model fitting. The purpose of this class is to allow users printing some
22 /// messages in each epoch or mini-batch training, for example the time spent,
23 /// the loss function values, etc.
24 ///
25 /// This default implementation is a silent version of the callback function
26 /// that basically does nothing. See the VerboseCallback class for a verbose
27 /// version that prints the loss function value in each mini-batch.
28 ///
29 class Callback
30 {
31     protected:
32         typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> Matrix;
33         typedef Eigen::RowVectorXi IntegerVector;
34 
35     public:
36         // Public members that will be set by the network during the training process
37         int m_nbatch;   // Number of total batches
38         int m_batch_id; // The index for the current mini-batch (0, 1, ..., m_nbatch-1)
39         int m_nepoch;   // Total number of epochs (one run on the whole data set) in the training process
40         int m_epoch_id; // The index for the current epoch (0, 1, ..., m_nepoch-1)
41 
Callback()42         Callback() :
43             m_nbatch(0), m_batch_id(0), m_nepoch(0), m_epoch_id(0)
44         {}
45 
~Callback()46         virtual ~Callback() {}
47 
48         // Before training a mini-batch
pre_training_batch(const Network * net,const Matrix & x,const Matrix & y)49         virtual void pre_training_batch(const Network* net, const Matrix& x,
50                                         const Matrix& y) {}
pre_training_batch(const Network * net,const Matrix & x,const IntegerVector & y)51         virtual void pre_training_batch(const Network* net, const Matrix& x,
52                                         const IntegerVector& y) {}
53 
54         // After a mini-batch is trained
post_training_batch(const Network * net,const Matrix & x,const Matrix & y)55         virtual void post_training_batch(const Network* net, const Matrix& x,
56                                          const Matrix& y) {}
post_training_batch(const Network * net,const Matrix & x,const IntegerVector & y)57         virtual void post_training_batch(const Network* net, const Matrix& x,
58                                          const IntegerVector& y) {}
59 };
60 
61 
62 } // namespace MiniDNN
63 
64 
65 #endif /* CALLBACK_H_ */
66