1 /**
2  * @file methods/ann/visitor/reward_set_visitor_impl.hpp
3  * @author Marcus Edel
4  *
 * Implementation of the RewardSetVisitor, which assigns a reward value to the
 * Reward() of a visited layer and of any nested layers held in its Model().
6  *
7  * mlpack is free software; you may redistribute it and/or modify it under the
8  * terms of the 3-clause BSD license.  You should have received a copy of the
9  * 3-clause BSD license along with mlpack.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #ifndef MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_IMPL_HPP
14 
15 // In case it hasn't been included yet.
16 #include "reward_set_visitor.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
21 //! RewardSetVisitor visitor class.
RewardSetVisitor(const double reward)22 inline RewardSetVisitor::RewardSetVisitor(const double reward) : reward(reward)
23 {
24   /* Nothing to do here. */
25 }
26 
27 template<typename LayerType>
operator ()(LayerType * layer) const28 inline void RewardSetVisitor::operator()(LayerType* layer) const
29 {
30   LayerReward(layer);
31 }
32 
operator ()(MoreTypes layer) const33 inline void RewardSetVisitor::operator()(MoreTypes layer) const
34 {
35   layer.apply_visitor(*this);
36 }
37 
38 template<typename T>
39 inline typename std::enable_if<
40     HasRewardCheck<T, double&(T::*)()>::value &&
41     HasModelCheck<T>::value, void>::type
LayerReward(T * layer) const42 RewardSetVisitor::LayerReward(T* layer) const
43 {
44   layer->Reward() = reward;
45 
46   for (size_t i = 0; i < layer->Model().size(); ++i)
47   {
48     boost::apply_visitor(RewardSetVisitor(reward),
49         layer->Model()[i]);
50   }
51 }
52 
53 template<typename T>
54 inline typename std::enable_if<
55     !HasRewardCheck<T, double&(T::*)()>::value &&
56     HasModelCheck<T>::value, void>::type
LayerReward(T * layer) const57 RewardSetVisitor::LayerReward(T* layer) const
58 {
59   for (size_t i = 0; i < layer->Model().size(); ++i)
60   {
61     boost::apply_visitor(RewardSetVisitor(reward),
62         layer->Model()[i]);
63   }
64 }
65 
66 template<typename T>
67 inline typename std::enable_if<
68     HasRewardCheck<T, double&(T::*)()>::value &&
69     !HasModelCheck<T>::value, void>::type
LayerReward(T * layer) const70 RewardSetVisitor::LayerReward(T* layer) const
71 {
72   layer->Reward() = reward;
73 }
74 
75 template<typename T>
76 inline typename std::enable_if<
77     !HasRewardCheck<T, double&(T::*)()>::value &&
78     !HasModelCheck<T>::value, void>::type
LayerReward(T *) const79 RewardSetVisitor::LayerReward(T* /* input */) const
80 {
81   /* Nothing to do here. */
82 }
83 
84 } // namespace ann
85 } // namespace mlpack
86 
87 #endif
88