1 // SPDX-License-Identifier: Apache-2.0 2 // 3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) 4 // Copyright 2008-2016 National ICT Australia (NICTA) 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ------------------------------------------------------------------------ 17 18 19 //! \addtogroup gmm_diag 20 //! @{ 21 22 23 namespace gmm_priv 24 { 25 26 template<typename eT> 27 class gmm_diag 28 { 29 public: 30 31 arma_aligned const Mat<eT> means; 32 arma_aligned const Mat<eT> dcovs; 33 arma_aligned const Row<eT> hefts; 34 35 // 36 // 37 38 inline ~gmm_diag(); 39 inline gmm_diag(); 40 41 inline gmm_diag(const gmm_diag& x); 42 inline gmm_diag& operator=(const gmm_diag& x); 43 44 inline explicit gmm_diag(const gmm_full<eT>& x); 45 inline gmm_diag& operator=(const gmm_full<eT>& x); 46 47 inline gmm_diag(const uword in_n_dims, const uword in_n_gaus); 48 inline void reset(const uword in_n_dims, const uword in_n_gaus); 49 inline void reset(); 50 51 template<typename T1, typename T2, typename T3> 52 inline void set_params(const Base<eT,T1>& in_means, const Base<eT,T2>& in_dcovs, const Base<eT,T3>& in_hefts); 53 54 template<typename T1> inline void set_means(const Base<eT,T1>& in_means); 55 template<typename T1> inline void set_dcovs(const Base<eT,T1>& in_dcovs); 56 template<typename T1> inline void set_hefts(const Base<eT,T1>& in_hefts); 57 58 inline uword n_dims() const; 59 inline uword n_gaus() const; 60 61 inline bool load(const std::string name); 62 inline bool save(const std::string name) const; 63 64 inline Col<eT> generate() const; 65 inline Mat<eT> generate(const uword N) const; 66 67 template<typename T1> inline eT log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true ))>::result* junk2 = nullptr) const; 68 template<typename T1> inline eT log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true ))>::result* junk2 = nullptr) const; 69 70 template<typename T1> inline Row<eT> log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2 = nullptr) const; 71 template<typename T1> inline Row<eT> log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2 = nullptr) const; 72 73 template<typename T1> inline eT sum_log_p(const Base<eT,T1>& expr) const; 74 template<typename T1> inline eT sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const; 75 76 template<typename T1> inline eT avg_log_p(const Base<eT,T1>& expr) const; 77 template<typename T1> inline eT avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const; 78 79 template<typename T1> inline uword assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true ))>::result* junk = nullptr) const; 80 template<typename T1> inline urowvec assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk = nullptr) const; 81 82 template<typename T1> inline urowvec raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const; 83 template<typename T1> inline Row<eT> norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const; 84 85 template<typename T1> 86 inline 87 bool 88 learn 89 ( 90 const Base<eT,T1>& data, 91 const uword n_gaus, 92 const gmm_dist_mode& dist_mode, 93 const gmm_seed_mode& seed_mode, 94 const uword km_iter, 95 const uword em_iter, 96 const eT var_floor, 97 const bool print_mode 98 ); 99 100 101 template<typename T1> 102 inline 103 bool 104 kmeans_wrapper 105 ( 106 Mat<eT>& user_means, 107 const Base<eT,T1>& data, 108 const uword n_gaus, 109 const gmm_seed_mode& seed_mode, 110 const uword km_iter, 111 const bool print_mode 112 ); 113 114 115 // 116 117 protected: 118 119 arma_aligned Mat<eT> inv_dcovs; 120 arma_aligned Row<eT> log_det_etc; 121 arma_aligned Row<eT> log_hefts; 122 arma_aligned Col<eT> mah_aux; 123 124 // 125 126 inline void init(const gmm_diag& x); 127 inline void init(const gmm_full<eT>& x); 128 129 inline void init(const uword in_n_dim, const uword in_n_gaus); 130 131 inline void init_constants(); 132 133 inline umat internal_gen_boundaries(const uword N) const; 134 135 inline eT internal_scalar_log_p(const eT* x ) const; 136 inline eT internal_scalar_log_p(const eT* x, const uword gaus_id) const; 137 138 inline Row<eT> internal_vec_log_p(const Mat<eT>& X ) const; 139 inline Row<eT> internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const; 140 141 inline eT internal_sum_log_p(const Mat<eT>& X ) const; 142 inline eT internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const; 143 144 inline eT internal_avg_log_p(const Mat<eT>& X ) const; 145 inline eT internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const; 146 147 inline uword internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const; 148 149 inline void internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const; 150 151 inline void internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const; 152 153 // 154 155 template<uword dist_id> inline void generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed); 156 157 template<uword dist_id> inline void generate_initial_params(const Mat<eT>& X, const eT var_floor); 158 159 template<uword dist_id> inline bool km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose, const char* signature); 160 161 // 162 163 inline bool em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose); 164 165 inline void em_update_params(const Mat<eT>& X, const umat& boundaries, field< Mat<eT> >& t_acc_means, field< Mat<eT> >& t_acc_dcovs, field< Col<eT> >& t_acc_norm_lhoods, field< Col<eT> >& t_gaus_log_lhoods, Col<eT>& t_progress_log_lhoods); 166 167 inline void em_generate_acc(const Mat<eT>& X, const uword start_index, const uword end_index, Mat<eT>& acc_means, Mat<eT>& acc_dcovs, Col<eT>& acc_norm_lhoods, Col<eT>& gaus_log_lhoods, eT& progress_log_lhood) const; 168 169 inline void em_fix_params(const eT var_floor); 170 }; 171 172 } 173 174 175 typedef gmm_priv::gmm_diag<double> gmm_diag; 176 typedef gmm_priv::gmm_diag<float> fgmm_diag; 177 178 179 //! @} 180