1 /** 2 * @file bbs_init.hpp 3 * @author Nanubala Gnana Sai 4 * 5 * The Bayesian Bootstrap (BBS) method of Weight Initialization. 6 * 7 * ensmallen is free software; you may redistribute it and/or modify it under 8 * the terms of the 3-clause BSD license. You should have received a copy of 9 * the 3-clause BSD license along with ensmallen. If not, see 10 * http://www.opensource.org/licenses/BSD-3-Clause for more information. 11 */ 12 #ifndef ENSMALLEN_MOEAD_BBS_HPP 13 #define ENSMALLEN_MOEAD_BBS_HPP 14 15 namespace ens { 16 17 /** 18 * The Bayesian Bootstrap method for initializing weights. Samples are randomly picked from uniform 19 * distribution followed by sorting and finding adjacent difference. This gives you a list of 20 * numbers which is guaranteed to sum up to 1. 21 * 22 * @code 23 * @article{rubin1981bayesian, 24 * title={The bayesian bootstrap}, 25 * author={Rubin, Donald B}, 26 * journal={The annals of statistics}, 27 * pages={130--134}, 28 * year={1981}, 29 * @endcode 30 * 31 */ 32 class BayesianBootstrap 33 { 34 public: 35 /** 36 * Constructor for Bayesian Bootstrap policy. 37 */ BayesianBootstrap()38 BayesianBootstrap() 39 { 40 /* Nothing to do. */ 41 } 42 43 /** 44 * Generate the reference direction matrix. 45 * 46 * @tparam MatType The type of the matrix used for constructing weights. 47 * @param numObjectives The dimensionality of objective space. 48 * @param numPoints The number of reference directions requested. 49 * @param epsilon Handle numerical stability after weight initialization. 50 */ 51 template<typename MatType> Generate(const size_t numObjectives,const size_t numPoints,const double epsilon)52 MatType Generate(const size_t numObjectives, 53 const size_t numPoints, 54 const double epsilon) 55 { 56 typedef typename MatType::elem_type ElemType; 57 typedef typename arma::Col<ElemType> VecType; 58 59 MatType weights(numObjectives, numPoints); 60 for (size_t pointIdx = 0; pointIdx < numPoints; ++pointIdx) 61 { 62 VecType referenceDirection(numObjectives + 1, arma::fill::randu); 63 referenceDirection(0) = 0; 64 referenceDirection(numObjectives) = 1; 65 referenceDirection = arma::sort(referenceDirection); 66 referenceDirection = arma::diff(referenceDirection); 67 weights.col(pointIdx) = std::move(referenceDirection) + epsilon; 68 } 69 70 return weights; 71 } 72 }; 73 74 } // namespace ens 75 76 #endif 77