1 /**
2  * @file methods/ann/layer/vr_class_reward_impl.hpp
3  * @author Marcus Edel
4  *
5  * Implementation of the VRClassReward class, which implements the variance
6  * reduced classification reinforcement layer.
7  *
8  * mlpack is free software; you may redistribute it and/or modify it under the
9  * terms of the 3-clause BSD license.  You should have received a copy of the
10  * 3-clause BSD license along with mlpack.  If not, see
11  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12  */
13 #ifndef MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP
15 
16 // In case it hasn't yet been included.
17 #include "vr_class_reward.hpp"
18 
19 #include "../visitor/reward_set_visitor.hpp"
20 
21 namespace mlpack {
22 namespace ann /** Artificial Neural Network. */ {
23 
24 template<typename InputDataType, typename OutputDataType>
VRClassReward(const double scale,const bool sizeAverage)25 VRClassReward<InputDataType, OutputDataType>::VRClassReward(
26     const double scale,
27     const bool sizeAverage) :
28     scale(scale),
29     sizeAverage(sizeAverage),
30     reward(0)
31 {
32   // Nothing to do here.
33 }
34 
35 template<typename InputDataType, typename OutputDataType>
36 template<typename InputType, typename TargetType>
Forward(const InputType & input,const TargetType & target)37 double VRClassReward<InputDataType, OutputDataType>::Forward(
38     const InputType& input, const TargetType& target)
39 {
40   double output = 0;
41   for (size_t i = 0; i < input.n_cols - 1; ++i)
42   {
43     size_t currentTarget = target(i) - 1;
44     Log::Assert(currentTarget < input.n_rows,
45         "Target class out of range.");
46 
47     output -= input(currentTarget, i);
48   }
49 
50   reward = 0;
51   arma::uword index = 0;
52 
53   for (size_t i = 0; i < input.n_cols - 1; ++i)
54   {
55     input.unsafe_col(i).max(index);
56     reward = ((index + 1) == target(i)) * scale;
57   }
58 
59   if (sizeAverage)
60   {
61     return output - reward / (input.n_cols - 1);
62   }
63 
64   return output - reward;
65 }
66 
67 template<typename InputDataType, typename OutputDataType>
68 template<typename InputType, typename TargetType, typename OutputType>
Backward(const InputType & input,const TargetType & target,OutputType & output)69 void VRClassReward<InputDataType, OutputDataType>::Backward(
70     const InputType& input,
71     const TargetType& target,
72     OutputType& output)
73 {
74   output = arma::zeros<OutputType>(input.n_rows, input.n_cols);
75   for (size_t i = 0; i < (input.n_cols - 1); ++i)
76   {
77     size_t currentTarget = target(i) - 1;
78     Log::Assert(currentTarget < input.n_rows,
79         "Target class out of range.");
80 
81     output(currentTarget, i) = -1;
82   }
83 
84   double vrReward = reward - input(0, 1);
85   if (sizeAverage)
86   {
87     vrReward /= input.n_cols - 1;
88   }
89 
90   const double norm = sizeAverage ? 2.0 / (input.n_cols - 1) : 2.0;
91 
92   output(0, 1) = norm * (input(0, 1) - reward);
93   boost::apply_visitor(RewardSetVisitor(vrReward), network.back());
94 }
95 
96 template<typename InputDataType, typename OutputDataType>
97 template<typename Archive>
serialize(Archive &,const unsigned int)98 void VRClassReward<InputDataType, OutputDataType>::serialize(
99     Archive& /* ar */,
100     const unsigned int /* version */)
101 {
102   // Nothing to do here.
103 }
104 
105 } // namespace ann
106 } // namespace mlpack
107 
108 #endif
109