1 /**
2  * @file bbs_init.hpp
3  * @author Nanubala Gnana Sai
4  *
5  * The Bayesian Bootstrap (BBS) method of Weight Initialization.
6  *
7  * ensmallen is free software; you may redistribute it and/or modify it under
8  * the terms of the 3-clause BSD license.  You should have received a copy of
9  * the 3-clause BSD license along with ensmallen.  If not, see
10  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11  */
12 #ifndef ENSMALLEN_MOEAD_BBS_HPP
13 #define ENSMALLEN_MOEAD_BBS_HPP
14 
15 namespace ens {
16 
17 /**
18  * The Bayesian Bootstrap method for initializing weights. Samples are randomly picked from uniform
19  * distribution followed by sorting and finding adjacent difference. This gives you a list of
20  * numbers which is guaranteed to sum up to 1.
21  *
22  * @code
23  * @article{rubin1981bayesian,
24  *   title={The bayesian bootstrap},
25  *   author={Rubin, Donald B},
26  *   journal={The annals of statistics},
27  *   pages={130--134},
28  *   year={1981},
29  * @endcode
30  *
31  */
32 class BayesianBootstrap
33 {
34  public:
35   /**
36    * Constructor for Bayesian Bootstrap policy.
37    */
BayesianBootstrap()38   BayesianBootstrap()
39   {
40     /* Nothing to do. */
41   }
42 
43   /**
44    * Generate the reference direction matrix.
45    *
46    * @tparam MatType The type of the matrix used for constructing weights.
47    * @param numObjectives The dimensionality of objective space.
48    * @param numPoints The number of reference directions requested.
49    * @param epsilon Handle numerical stability after weight initialization.
50    */
51   template<typename MatType>
Generate(const size_t numObjectives,const size_t numPoints,const double epsilon)52   MatType Generate(const size_t numObjectives,
53                    const size_t numPoints,
54                    const double epsilon)
55   {
56       typedef typename MatType::elem_type ElemType;
57       typedef typename arma::Col<ElemType> VecType;
58 
59       MatType weights(numObjectives, numPoints);
60       for (size_t pointIdx = 0; pointIdx < numPoints; ++pointIdx)
61       {
62         VecType referenceDirection(numObjectives + 1, arma::fill::randu);
63         referenceDirection(0) = 0;
64         referenceDirection(numObjectives) = 1;
65         referenceDirection = arma::sort(referenceDirection);
66         referenceDirection = arma::diff(referenceDirection);
67         weights.col(pointIdx) = std::move(referenceDirection) + epsilon;
68       }
69 
70       return weights;
71   }
72 };
73 
74 } // namespace ens
75 
76 #endif
77