1 /**
2 * @file methods/ann/visitor/reward_set_visitor_impl.hpp
3 * @author Marcus Edel
4 *
 * Implementation of the RewardSetVisitor, which sets a layer's reward through
 * its Reward() function and propagates it to any layers held in its Model().
6 *
7 * mlpack is free software; you may redistribute it and/or modify it under the
8 * terms of the 3-clause BSD license. You should have received a copy of the
9 * 3-clause BSD license along with mlpack. If not, see
10 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11 */
12 #ifndef MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_IMPL_HPP
13 #define MLPACK_METHODS_ANN_VISITOR_REWARD_SET_VISITOR_IMPL_HPP
14
15 // In case it hasn't been included yet.
16 #include "reward_set_visitor.hpp"
17
18 namespace mlpack {
19 namespace ann {
20
21 //! RewardSetVisitor visitor class.
RewardSetVisitor(const double reward)22 inline RewardSetVisitor::RewardSetVisitor(const double reward) : reward(reward)
23 {
24 /* Nothing to do here. */
25 }
26
27 template<typename LayerType>
operator ()(LayerType * layer) const28 inline void RewardSetVisitor::operator()(LayerType* layer) const
29 {
30 LayerReward(layer);
31 }
32
operator ()(MoreTypes layer) const33 inline void RewardSetVisitor::operator()(MoreTypes layer) const
34 {
35 layer.apply_visitor(*this);
36 }
37
38 template<typename T>
39 inline typename std::enable_if<
40 HasRewardCheck<T, double&(T::*)()>::value &&
41 HasModelCheck<T>::value, void>::type
LayerReward(T * layer) const42 RewardSetVisitor::LayerReward(T* layer) const
43 {
44 layer->Reward() = reward;
45
46 for (size_t i = 0; i < layer->Model().size(); ++i)
47 {
48 boost::apply_visitor(RewardSetVisitor(reward),
49 layer->Model()[i]);
50 }
51 }
52
53 template<typename T>
54 inline typename std::enable_if<
55 !HasRewardCheck<T, double&(T::*)()>::value &&
56 HasModelCheck<T>::value, void>::type
LayerReward(T * layer) const57 RewardSetVisitor::LayerReward(T* layer) const
58 {
59 for (size_t i = 0; i < layer->Model().size(); ++i)
60 {
61 boost::apply_visitor(RewardSetVisitor(reward),
62 layer->Model()[i]);
63 }
64 }
65
66 template<typename T>
67 inline typename std::enable_if<
68 HasRewardCheck<T, double&(T::*)()>::value &&
69 !HasModelCheck<T>::value, void>::type
LayerReward(T * layer) const70 RewardSetVisitor::LayerReward(T* layer) const
71 {
72 layer->Reward() = reward;
73 }
74
75 template<typename T>
76 inline typename std::enable_if<
77 !HasRewardCheck<T, double&(T::*)()>::value &&
78 !HasModelCheck<T>::value, void>::type
LayerReward(T *) const79 RewardSetVisitor::LayerReward(T* /* input */) const
80 {
81 /* Nothing to do here. */
82 }
83
84 } // namespace ann
85 } // namespace mlpack
86
87 #endif
88