1 /**
2 * @file methods/ann/layer/vr_class_reward_impl.hpp
3 * @author Marcus Edel
4 *
5 * Implementation of the VRClassReward class, which implements the variance
6 * reduced classification reinforcement layer.
7 *
8 * mlpack is free software; you may redistribute it and/or modify it under the
9 * terms of the 3-clause BSD license. You should have received a copy of the
10 * 3-clause BSD license along with mlpack. If not, see
11 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
12 */
13 #ifndef MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP
14 #define MLPACK_METHODS_ANN_LAYER_VR_CLASS_REWARD_IMPL_HPP
15
16 // In case it hasn't yet been included.
17 #include "vr_class_reward.hpp"
18
19 #include "../visitor/reward_set_visitor.hpp"
20
21 namespace mlpack {
22 namespace ann /** Artificial Neural Network. */ {
23
24 template<typename InputDataType, typename OutputDataType>
VRClassReward(const double scale,const bool sizeAverage)25 VRClassReward<InputDataType, OutputDataType>::VRClassReward(
26 const double scale,
27 const bool sizeAverage) :
28 scale(scale),
29 sizeAverage(sizeAverage),
30 reward(0)
31 {
32 // Nothing to do here.
33 }
34
35 template<typename InputDataType, typename OutputDataType>
36 template<typename InputType, typename TargetType>
Forward(const InputType & input,const TargetType & target)37 double VRClassReward<InputDataType, OutputDataType>::Forward(
38 const InputType& input, const TargetType& target)
39 {
40 double output = 0;
41 for (size_t i = 0; i < input.n_cols - 1; ++i)
42 {
43 size_t currentTarget = target(i) - 1;
44 Log::Assert(currentTarget < input.n_rows,
45 "Target class out of range.");
46
47 output -= input(currentTarget, i);
48 }
49
50 reward = 0;
51 arma::uword index = 0;
52
53 for (size_t i = 0; i < input.n_cols - 1; ++i)
54 {
55 input.unsafe_col(i).max(index);
56 reward = ((index + 1) == target(i)) * scale;
57 }
58
59 if (sizeAverage)
60 {
61 return output - reward / (input.n_cols - 1);
62 }
63
64 return output - reward;
65 }
66
67 template<typename InputDataType, typename OutputDataType>
68 template<typename InputType, typename TargetType, typename OutputType>
Backward(const InputType & input,const TargetType & target,OutputType & output)69 void VRClassReward<InputDataType, OutputDataType>::Backward(
70 const InputType& input,
71 const TargetType& target,
72 OutputType& output)
73 {
74 output = arma::zeros<OutputType>(input.n_rows, input.n_cols);
75 for (size_t i = 0; i < (input.n_cols - 1); ++i)
76 {
77 size_t currentTarget = target(i) - 1;
78 Log::Assert(currentTarget < input.n_rows,
79 "Target class out of range.");
80
81 output(currentTarget, i) = -1;
82 }
83
84 double vrReward = reward - input(0, 1);
85 if (sizeAverage)
86 {
87 vrReward /= input.n_cols - 1;
88 }
89
90 const double norm = sizeAverage ? 2.0 / (input.n_cols - 1) : 2.0;
91
92 output(0, 1) = norm * (input(0, 1) - reward);
93 boost::apply_visitor(RewardSetVisitor(vrReward), network.back());
94 }
95
96 template<typename InputDataType, typename OutputDataType>
97 template<typename Archive>
serialize(Archive &,const unsigned int)98 void VRClassReward<InputDataType, OutputDataType>::serialize(
99 Archive& /* ar */,
100 const unsigned int /* version */)
101 {
102 // Nothing to do here.
103 }
104
105 } // namespace ann
106 } // namespace mlpack
107
108 #endif
109