1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup gmm_diag
20 //! @{
21 
22 
23 namespace gmm_priv
24 {
25 
26 
27 template<typename eT>
28 inline
~gmm_diag()29 gmm_diag<eT>::~gmm_diag()
30   {
31   arma_extra_debug_sigprint_this(this);
32 
33   arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
34   }
35 
36 
37 
38 template<typename eT>
39 inline
gmm_diag()40 gmm_diag<eT>::gmm_diag()
41   {
42   arma_extra_debug_sigprint_this(this);
43   }
44 
45 
46 
47 template<typename eT>
48 inline
gmm_diag(const gmm_diag<eT> & x)49 gmm_diag<eT>::gmm_diag(const gmm_diag<eT>& x)
50   {
51   arma_extra_debug_sigprint_this(this);
52 
53   init(x);
54   }
55 
56 
57 
58 template<typename eT>
59 inline
60 gmm_diag<eT>&
operator =(const gmm_diag<eT> & x)61 gmm_diag<eT>::operator=(const gmm_diag<eT>& x)
62   {
63   arma_extra_debug_sigprint();
64 
65   init(x);
66 
67   return *this;
68   }
69 
70 
71 
72 template<typename eT>
73 inline
gmm_diag(const gmm_full<eT> & x)74 gmm_diag<eT>::gmm_diag(const gmm_full<eT>& x)
75   {
76   arma_extra_debug_sigprint_this(this);
77 
78   init(x);
79   }
80 
81 
82 
83 template<typename eT>
84 inline
85 gmm_diag<eT>&
operator =(const gmm_full<eT> & x)86 gmm_diag<eT>::operator=(const gmm_full<eT>& x)
87   {
88   arma_extra_debug_sigprint();
89 
90   init(x);
91 
92   return *this;
93   }
94 
95 
96 
97 template<typename eT>
98 inline
gmm_diag(const uword in_n_dims,const uword in_n_gaus)99 gmm_diag<eT>::gmm_diag(const uword in_n_dims, const uword in_n_gaus)
100   {
101   arma_extra_debug_sigprint_this(this);
102 
103   init(in_n_dims, in_n_gaus);
104   }
105 
106 
107 
108 template<typename eT>
109 inline
110 void
reset()111 gmm_diag<eT>::reset()
112   {
113   arma_extra_debug_sigprint();
114 
115   init(0, 0);
116   }
117 
118 
119 
120 template<typename eT>
121 inline
122 void
reset(const uword in_n_dims,const uword in_n_gaus)123 gmm_diag<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
124   {
125   arma_extra_debug_sigprint();
126 
127   init(in_n_dims, in_n_gaus);
128   }
129 
130 
131 
132 template<typename eT>
133 template<typename T1, typename T2, typename T3>
134 inline
135 void
set_params(const Base<eT,T1> & in_means_expr,const Base<eT,T2> & in_dcovs_expr,const Base<eT,T3> & in_hefts_expr)136 gmm_diag<eT>::set_params(const Base<eT,T1>& in_means_expr, const Base<eT,T2>& in_dcovs_expr, const Base<eT,T3>& in_hefts_expr)
137   {
138   arma_extra_debug_sigprint();
139 
140   const unwrap<T1> tmp1(in_means_expr.get_ref());
141   const unwrap<T2> tmp2(in_dcovs_expr.get_ref());
142   const unwrap<T3> tmp3(in_hefts_expr.get_ref());
143 
144   const Mat<eT>& in_means = tmp1.M;
145   const Mat<eT>& in_dcovs = tmp2.M;
146   const Mat<eT>& in_hefts = tmp3.M;
147 
148   arma_debug_check
149     (
150     (arma::size(in_means) != arma::size(in_dcovs)) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
151     "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes"
152     );
153 
154   arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_params(): given means have non-finite values" );
155   arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_params(): given dcovs have non-finite values" );
156   arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_params(): given hefts have non-finite values" );
157 
158   arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" );
159   arma_debug_check( (any(vectorise(in_hefts) <  eT(0))), "gmm_diag::set_params(): given hefts have negative values"         );
160 
161   const eT s = accu(in_hefts);
162 
163   arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_params(): sum of given hefts is not 1" );
164 
165   access::rw(means) = in_means;
166   access::rw(dcovs) = in_dcovs;
167   access::rw(hefts) = in_hefts;
168 
169   init_constants();
170   }
171 
172 
173 
174 template<typename eT>
175 template<typename T1>
176 inline
177 void
set_means(const Base<eT,T1> & in_means_expr)178 gmm_diag<eT>::set_means(const Base<eT,T1>& in_means_expr)
179   {
180   arma_extra_debug_sigprint();
181 
182   const unwrap<T1> tmp(in_means_expr.get_ref());
183 
184   const Mat<eT>& in_means = tmp.M;
185 
186   arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_diag::set_means(): given means have incompatible size" );
187   arma_debug_check( (in_means.is_finite() == false),             "gmm_diag::set_means(): given means have non-finite values" );
188 
189   access::rw(means) = in_means;
190   }
191 
192 
193 
194 template<typename eT>
195 template<typename T1>
196 inline
197 void
set_dcovs(const Base<eT,T1> & in_dcovs_expr)198 gmm_diag<eT>::set_dcovs(const Base<eT,T1>& in_dcovs_expr)
199   {
200   arma_extra_debug_sigprint();
201 
202   const unwrap<T1> tmp(in_dcovs_expr.get_ref());
203 
204   const Mat<eT>& in_dcovs = tmp.M;
205 
206   arma_debug_check( (arma::size(in_dcovs) != arma::size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size"       );
207   arma_debug_check( (in_dcovs.is_finite() == false),             "gmm_diag::set_dcovs(): given dcovs have non-finite values"       );
208   arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))),         "gmm_diag::set_dcovs(): given dcovs have negative or zero values" );
209 
210   access::rw(dcovs) = in_dcovs;
211 
212   init_constants();
213   }
214 
215 
216 
217 template<typename eT>
218 template<typename T1>
219 inline
220 void
set_hefts(const Base<eT,T1> & in_hefts_expr)221 gmm_diag<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
222   {
223   arma_extra_debug_sigprint();
224 
225   const unwrap<T1> tmp(in_hefts_expr.get_ref());
226 
227   const Mat<eT>& in_hefts = tmp.M;
228 
229   arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" );
230   arma_debug_check( (in_hefts.is_finite() == false),             "gmm_diag::set_hefts(): given hefts have non-finite values" );
231   arma_debug_check( (any(vectorise(in_hefts) <  eT(0))),         "gmm_diag::set_hefts(): given hefts have negative values"   );
232 
233   const eT s = accu(in_hefts);
234 
235   arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_hefts(): sum of given hefts is not 1" );
236 
237   // make sure all hefts are positive and non-zero
238 
239   const eT* in_hefts_mem = in_hefts.memptr();
240         eT*    hefts_mem = access::rw(hefts).memptr();
241 
242   for(uword i=0; i < hefts.n_elem; ++i)
243     {
244     hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
245     }
246 
247   access::rw(hefts) /= accu(hefts);
248 
249   log_hefts = log(hefts);
250   }
251 
252 
253 
254 template<typename eT>
255 inline
256 uword
n_dims() const257 gmm_diag<eT>::n_dims() const
258   {
259   return means.n_rows;
260   }
261 
262 
263 
264 template<typename eT>
265 inline
266 uword
n_gaus() const267 gmm_diag<eT>::n_gaus() const
268   {
269   return means.n_cols;
270   }
271 
272 
273 
274 template<typename eT>
275 inline
276 bool
load(const std::string name)277 gmm_diag<eT>::load(const std::string name)
278   {
279   arma_extra_debug_sigprint();
280 
281   Cube<eT> Q;
282 
283   bool status = Q.load(name, arma_binary);
284 
285   if( (status == false) || (Q.n_slices != 2) )
286     {
287     reset();
288     arma_debug_warn_level(3, "gmm_diag::load(): problem with loading or incompatible format");
289     return false;
290     }
291 
292   if( (Q.n_rows < 2) || (Q.n_cols < 1) )
293     {
294     reset();
295     return true;
296     }
297 
298   access::rw(hefts) = Q.slice(0).row(0);
299   access::rw(means) = Q.slice(0).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
300   access::rw(dcovs) = Q.slice(1).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
301 
302   init_constants();
303 
304   return true;
305   }
306 
307 
308 
309 template<typename eT>
310 inline
311 bool
save(const std::string name) const312 gmm_diag<eT>::save(const std::string name) const
313   {
314   arma_extra_debug_sigprint();
315 
316   Cube<eT> Q(means.n_rows + 1, means.n_cols, 2, arma_nozeros_indicator());
317 
318   if(Q.n_elem > 0)
319     {
320     Q.slice(0).row(0) = hefts;
321     Q.slice(1).row(0).zeros();  // reserved for future use
322 
323     Q.slice(0).submat(1, 0, arma::size(means)) = means;
324     Q.slice(1).submat(1, 0, arma::size(dcovs)) = dcovs;
325     }
326 
327   const bool status = Q.save(name, arma_binary);
328 
329   return status;
330   }
331 
332 
333 
334 template<typename eT>
335 inline
336 Col<eT>
generate() const337 gmm_diag<eT>::generate() const
338   {
339   arma_extra_debug_sigprint();
340 
341   const uword N_dims = means.n_rows;
342   const uword N_gaus = means.n_cols;
343 
344   Col<eT> out( ((N_gaus > 0) ? N_dims : uword(0)), fill::randn );
345 
346   if(N_gaus > 0)
347     {
348     const double val = randu<double>();
349 
350     double csum    = double(0);
351     uword  gaus_id = 0;
352 
353     for(uword j=0; j < N_gaus; ++j)
354       {
355       csum += hefts[j];
356 
357       if(val <= csum)  { gaus_id = j; break; }
358       }
359 
360     out %= sqrt(dcovs.col(gaus_id));
361     out += means.col(gaus_id);
362     }
363 
364   return out;
365   }
366 
367 
368 
369 template<typename eT>
370 inline
371 Mat<eT>
generate(const uword N_vec) const372 gmm_diag<eT>::generate(const uword N_vec) const
373   {
374   arma_extra_debug_sigprint();
375 
376   const uword N_dims = means.n_rows;
377   const uword N_gaus = means.n_cols;
378 
379   Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn );
380 
381   if(N_gaus > 0)
382     {
383     const eT* hefts_mem = hefts.memptr();
384 
385     const Mat<eT> sqrt_dcovs = sqrt(dcovs);
386 
387     for(uword i=0; i < N_vec; ++i)
388       {
389       const double val = randu<double>();
390 
391       double csum    = double(0);
392       uword  gaus_id = 0;
393 
394       for(uword j=0; j < N_gaus; ++j)
395         {
396         csum += hefts_mem[j];
397 
398         if(val <= csum)  { gaus_id = j; break; }
399         }
400 
401       subview_col<eT> out_col = out.col(i);
402 
403       out_col %= sqrt_dcovs.col(gaus_id);
404       out_col += means.col(gaus_id);
405       }
406     }
407 
408   return out;
409   }
410 
411 
412 
413 template<typename eT>
414 template<typename T1>
415 inline
416 eT
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const417 gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
418   {
419   arma_extra_debug_sigprint();
420   arma_ignore(junk1);
421   arma_ignore(junk2);
422 
423   const quasi_unwrap<T1> tmp(expr);
424 
425   arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
426 
427   return internal_scalar_log_p( tmp.M.memptr() );
428   }
429 
430 
431 
432 template<typename eT>
433 template<typename T1>
434 inline
435 eT
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const436 gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
437   {
438   arma_extra_debug_sigprint();
439   arma_ignore(junk2);
440 
441   const quasi_unwrap<T1> tmp(expr);
442 
443   arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
444 
445   arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
446 
447   return internal_scalar_log_p( tmp.M.memptr(), gaus_id );
448   }
449 
450 
451 
452 template<typename eT>
453 template<typename T1>
454 inline
455 Row<eT>
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const456 gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
457   {
458   arma_extra_debug_sigprint();
459   arma_ignore(junk1);
460   arma_ignore(junk2);
461 
462   const quasi_unwrap<T1> tmp(expr);
463 
464   const Mat<eT>& X = tmp.M;
465 
466   return internal_vec_log_p(X);
467   }
468 
469 
470 
471 template<typename eT>
472 template<typename T1>
473 inline
474 Row<eT>
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const475 gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
476   {
477   arma_extra_debug_sigprint();
478   arma_ignore(junk2);
479 
480   const quasi_unwrap<T1> tmp(expr);
481 
482   const Mat<eT>& X = tmp.M;
483 
484   return internal_vec_log_p(X, gaus_id);
485   }
486 
487 
488 
489 template<typename eT>
490 template<typename T1>
491 inline
492 eT
sum_log_p(const Base<eT,T1> & expr) const493 gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr) const
494   {
495   arma_extra_debug_sigprint();
496 
497   const quasi_unwrap<T1> tmp(expr.get_ref());
498 
499   const Mat<eT>& X = tmp.M;
500 
501   return internal_sum_log_p(X);
502   }
503 
504 
505 
506 template<typename eT>
507 template<typename T1>
508 inline
509 eT
sum_log_p(const Base<eT,T1> & expr,const uword gaus_id) const510 gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
511   {
512   arma_extra_debug_sigprint();
513 
514   const quasi_unwrap<T1> tmp(expr.get_ref());
515 
516   const Mat<eT>& X = tmp.M;
517 
518   return internal_sum_log_p(X, gaus_id);
519   }
520 
521 
522 
523 template<typename eT>
524 template<typename T1>
525 inline
526 eT
avg_log_p(const Base<eT,T1> & expr) const527 gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr) const
528   {
529   arma_extra_debug_sigprint();
530 
531   const quasi_unwrap<T1> tmp(expr.get_ref());
532 
533   const Mat<eT>& X = tmp.M;
534 
535   return internal_avg_log_p(X);
536   }
537 
538 
539 
540 template<typename eT>
541 template<typename T1>
542 inline
543 eT
avg_log_p(const Base<eT,T1> & expr,const uword gaus_id) const544 gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
545   {
546   arma_extra_debug_sigprint();
547 
548   const quasi_unwrap<T1> tmp(expr.get_ref());
549 
550   const Mat<eT>& X = tmp.M;
551 
552   return internal_avg_log_p(X, gaus_id);
553   }
554 
555 
556 
557 template<typename eT>
558 template<typename T1>
559 inline
560 uword
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk) const561 gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
562   {
563   arma_extra_debug_sigprint();
564   arma_ignore(junk);
565 
566   const quasi_unwrap<T1> tmp(expr);
567 
568   const Mat<eT>& X = tmp.M;
569 
570   return internal_scalar_assign(X, dist);
571   }
572 
573 
574 
575 template<typename eT>
576 template<typename T1>
577 inline
578 urowvec
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk) const579 gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
580   {
581   arma_extra_debug_sigprint();
582   arma_ignore(junk);
583 
584   urowvec out;
585 
586   const quasi_unwrap<T1> tmp(expr);
587 
588   const Mat<eT>& X = tmp.M;
589 
590   internal_vec_assign(out, X, dist);
591 
592   return out;
593   }
594 
595 
596 
597 template<typename eT>
598 template<typename T1>
599 inline
600 urowvec
raw_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const601 gmm_diag<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
602   {
603   arma_extra_debug_sigprint();
604 
605   const unwrap<T1>   tmp(expr.get_ref());
606   const Mat<eT>& X = tmp.M;
607 
608   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::raw_hist(): incompatible dimensions" );
609 
610   arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::raw_hist(): unsupported distance mode" );
611 
612   urowvec hist;
613 
614   internal_raw_hist(hist, X, dist_mode);
615 
616   return hist;
617   }
618 
619 
620 
621 template<typename eT>
622 template<typename T1>
623 inline
624 Row<eT>
norm_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const625 gmm_diag<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
626   {
627   arma_extra_debug_sigprint();
628 
629   const unwrap<T1>   tmp(expr.get_ref());
630   const Mat<eT>& X = tmp.M;
631 
632   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::norm_hist(): incompatible dimensions" );
633 
634   arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::norm_hist(): unsupported distance mode" );
635 
636   urowvec hist;
637 
638   internal_raw_hist(hist, X, dist_mode);
639 
640   const uword  hist_n_elem = hist.n_elem;
641   const uword* hist_mem    = hist.memptr();
642 
643   eT acc = eT(0);
644   for(uword i=0; i<hist_n_elem; ++i)  { acc += eT(hist_mem[i]); }
645 
646   if(acc == eT(0))  { acc = eT(1); }
647 
648   Row<eT> out(hist_n_elem, arma_nozeros_indicator());
649 
650   eT* out_mem = out.memptr();
651 
652   for(uword i=0; i<hist_n_elem; ++i)  { out_mem[i] = eT(hist_mem[i]) / acc; }
653 
654   return out;
655   }
656 
657 
658 
659 template<typename eT>
660 template<typename T1>
661 inline
662 bool
learn(const Base<eT,T1> & data,const uword N_gaus,const gmm_dist_mode & dist_mode,const gmm_seed_mode & seed_mode,const uword km_iter,const uword em_iter,const eT var_floor,const bool print_mode)663 gmm_diag<eT>::learn
664   (
665   const Base<eT,T1>&   data,
666   const uword          N_gaus,
667   const gmm_dist_mode& dist_mode,
668   const gmm_seed_mode& seed_mode,
669   const uword          km_iter,
670   const uword          em_iter,
671   const eT             var_floor,
672   const bool           print_mode
673   )
674   {
675   arma_extra_debug_sigprint();
676 
677   const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
678 
679   const bool seed_mode_ok = \
680        (seed_mode == keep_existing)
681     || (seed_mode == static_subset)
682     || (seed_mode == static_spread)
683     || (seed_mode == random_subset)
684     || (seed_mode == random_spread);
685 
686   arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" );
687   arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown seed_mode"                        );
688   arma_debug_check( (var_floor < eT(0)    ), "gmm_diag::learn(): variance floor is negative"               );
689 
690   const unwrap<T1>   tmp_X(data.get_ref());
691   const Mat<eT>& X = tmp_X.M;
692 
693   if(X.is_empty()          )  { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix is empty"             ); return false; }
694   if(X.is_finite() == false)  { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix has non-finite values"); return false; }
695 
696   if(N_gaus == 0)  { reset(); return true; }
697 
698   if(dist_mode == maha_dist)
699     {
700     mah_aux = var(X,1,1);
701 
702     const uword mah_aux_n_elem = mah_aux.n_elem;
703           eT*   mah_aux_mem    = mah_aux.memptr();
704 
705     for(uword i=0; i < mah_aux_n_elem; ++i)
706       {
707       const eT val = mah_aux_mem[i];
708 
709       mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
710       }
711     }
712 
713 
714   // copy current model, in case of failure by k-means and/or EM
715 
716   const gmm_diag<eT> orig = (*this);
717 
718 
719   // initial means
720 
721   if(seed_mode == keep_existing)
722     {
723     if(means.is_empty()        )  { arma_debug_warn_level(3, "gmm_diag::learn(): no existing means"      ); return false; }
724     if(X.n_rows != means.n_rows)  { arma_debug_warn_level(3, "gmm_diag::learn(): dimensionality mismatch"); return false; }
725 
726     // TODO: also check for number of vectors?
727     }
728   else
729     {
730     if(X.n_cols < N_gaus)  { arma_debug_warn_level(3, "gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; }
731 
732     reset(X.n_rows, N_gaus);
733 
734     if(print_mode)  { get_cout_stream() << "gmm_diag::learn(): generating initial means\n"; get_cout_stream().flush(); }
735 
736          if(dist_mode == eucl_dist)  { generate_initial_means<1>(X, seed_mode); }
737     else if(dist_mode == maha_dist)  { generate_initial_means<2>(X, seed_mode); }
738     }
739 
740 
741   // k-means
742 
743   if(km_iter > 0)
744     {
745     const arma_ostream_state stream_state(get_cout_stream());
746 
747     bool status = false;
748 
749          if(dist_mode == eucl_dist)  { status = km_iterate<1>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
750     else if(dist_mode == maha_dist)  { status = km_iterate<2>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
751 
752     stream_state.restore(get_cout_stream());
753 
754     if(status == false)  { arma_debug_warn_level(3, "gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
755     }
756 
757 
758   // initial dcovs
759 
760   const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
761 
762   if(seed_mode != keep_existing)
763     {
764     if(print_mode)  { get_cout_stream() << "gmm_diag::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
765 
766          if(dist_mode == eucl_dist)  { generate_initial_params<1>(X, var_floor_actual); }
767     else if(dist_mode == maha_dist)  { generate_initial_params<2>(X, var_floor_actual); }
768     }
769 
770 
771   // EM algorithm
772 
773   if(em_iter > 0)
774     {
775     const arma_ostream_state stream_state(get_cout_stream());
776 
777     const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
778 
779     stream_state.restore(get_cout_stream());
780 
781     if(status == false)  { arma_debug_warn_level(3, "gmm_diag::learn(): EM algorithm failed"); init(orig); return false; }
782     }
783 
784   mah_aux.reset();
785 
786   init_constants();
787 
788   return true;
789   }
790 
791 
792 
793 template<typename eT>
794 template<typename T1>
795 inline
796 bool
kmeans_wrapper(Mat<eT> & user_means,const Base<eT,T1> & data,const uword N_gaus,const gmm_seed_mode & seed_mode,const uword km_iter,const bool print_mode)797 gmm_diag<eT>::kmeans_wrapper
798   (
799         Mat<eT>&       user_means,
800   const Base<eT,T1>&   data,
801   const uword          N_gaus,
802   const gmm_seed_mode& seed_mode,
803   const uword          km_iter,
804   const bool           print_mode
805   )
806   {
807   arma_extra_debug_sigprint();
808 
809   const bool seed_mode_ok = \
810        (seed_mode == keep_existing)
811     || (seed_mode == static_subset)
812     || (seed_mode == static_spread)
813     || (seed_mode == random_subset)
814     || (seed_mode == random_spread);
815 
816   arma_debug_check( (seed_mode_ok == false), "kmeans(): unknown seed_mode" );
817 
818   const unwrap<T1>   tmp_X(data.get_ref());
819   const Mat<eT>& X = tmp_X.M;
820 
821   if(X.is_empty()          )  { arma_debug_warn_level(3, "kmeans(): given matrix is empty"             ); return false; }
822   if(X.is_finite() == false)  { arma_debug_warn_level(3, "kmeans(): given matrix has non-finite values"); return false; }
823 
824   if(N_gaus == 0)  { reset(); return true; }
825 
826 
827   // initial means
828 
829   if(seed_mode == keep_existing)
830     {
831     access::rw(means) = user_means;
832 
833     if(means.is_empty()        )  { arma_debug_warn_level(3, "kmeans(): no existing means"      ); return false; }
834     if(X.n_rows != means.n_rows)  { arma_debug_warn_level(3, "kmeans(): dimensionality mismatch"); return false; }
835 
836     // TODO: also check for number of vectors?
837     }
838   else
839     {
840     if(X.n_cols < N_gaus)  { arma_debug_warn_level(3, "kmeans(): number of vectors is less than number of means"); return false; }
841 
842     access::rw(means).zeros(X.n_rows, N_gaus);
843 
844     if(print_mode)  { get_cout_stream() << "kmeans(): generating initial means\n"; }
845 
846     generate_initial_means<1>(X, seed_mode);
847     }
848 
849 
850   // k-means
851 
852   if(km_iter > 0)
853     {
854     const arma_ostream_state stream_state(get_cout_stream());
855 
856     bool status = false;
857 
858     status = km_iterate<1>(X, km_iter, print_mode, "kmeans()");
859 
860     stream_state.restore(get_cout_stream());
861 
862     if(status == false)  { arma_debug_warn_level(3, "kmeans(): clustering failed; not enough data, or too many means requested"); return false; }
863     }
864 
865   return true;
866   }
867 
868 
869 
870 //
871 //
872 //
873 
874 
875 
876 template<typename eT>
877 inline
878 void
init(const gmm_diag<eT> & x)879 gmm_diag<eT>::init(const gmm_diag<eT>& x)
880   {
881   arma_extra_debug_sigprint();
882 
883   gmm_diag<eT>& t = *this;
884 
885   if(&t != &x)
886     {
887     access::rw(t.means) = x.means;
888     access::rw(t.dcovs) = x.dcovs;
889     access::rw(t.hefts) = x.hefts;
890 
891     init_constants();
892     }
893   }
894 
895 
896 
897 template<typename eT>
898 inline
899 void
init(const gmm_full<eT> & x)900 gmm_diag<eT>::init(const gmm_full<eT>& x)
901   {
902   arma_extra_debug_sigprint();
903 
904   access::rw(hefts) = x.hefts;
905   access::rw(means) = x.means;
906 
907   const uword N_dims = x.means.n_rows;
908   const uword N_gaus = x.means.n_cols;
909 
910   access::rw(dcovs).zeros(N_dims,N_gaus);
911 
912   for(uword g=0; g < N_gaus; ++g)
913     {
914     const Mat<eT>& fcov = x.fcovs.slice(g);
915 
916     eT* dcov_mem = access::rw(dcovs).colptr(g);
917 
918     for(uword d=0; d < N_dims; ++d)
919       {
920       dcov_mem[d] = fcov.at(d,d);
921       }
922     }
923 
924   init_constants();
925   }
926 
927 
928 
929 template<typename eT>
930 inline
931 void
init(const uword in_n_dims,const uword in_n_gaus)932 gmm_diag<eT>::init(const uword in_n_dims, const uword in_n_gaus)
933   {
934   arma_extra_debug_sigprint();
935 
936   access::rw(means).zeros(in_n_dims, in_n_gaus);
937 
938   access::rw(dcovs).ones(in_n_dims, in_n_gaus);
939 
940   access::rw(hefts).set_size(in_n_gaus);
941 
942   access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
943 
944   init_constants();
945   }
946 
947 
948 
949 template<typename eT>
950 inline
951 void
init_constants()952 gmm_diag<eT>::init_constants()
953   {
954   arma_extra_debug_sigprint();
955 
956   const uword N_dims = means.n_rows;
957   const uword N_gaus = means.n_cols;
958 
959   //
960 
961   inv_dcovs.copy_size(dcovs);
962 
963   const eT*     dcovs_mem =     dcovs.memptr();
964         eT* inv_dcovs_mem = inv_dcovs.memptr();
965 
966   const uword dcovs_n_elem = dcovs.n_elem;
967 
968   for(uword i=0; i < dcovs_n_elem; ++i)
969     {
970     inv_dcovs_mem[i] = eT(1) / (std::max)( dcovs_mem[i], std::numeric_limits<eT>::min() );
971     }
972 
973   //
974 
975   const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
976 
977   log_det_etc.set_size(N_gaus);
978 
979   for(uword g=0; g < N_gaus; ++g)
980     {
981     const eT* dcovs_colmem = dcovs.colptr(g);
982 
983     eT log_det_val = eT(0);
984 
985     for(uword d=0; d < N_dims; ++d)
986       {
987       log_det_val += std::log( (std::max)( dcovs_colmem[d], std::numeric_limits<eT>::min() ) );
988       }
989 
990     log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
991     }
992 
993   //
994 
995   eT* hefts_mem = access::rw(hefts).memptr();
996 
997   for(uword g=0; g < N_gaus; ++g)
998     {
999     hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
1000     }
1001 
1002   log_hefts = log(hefts);
1003   }
1004 
1005 
1006 
1007 template<typename eT>
1008 inline
1009 umat
internal_gen_boundaries(const uword N) const1010 gmm_diag<eT>::internal_gen_boundaries(const uword N) const
1011   {
1012   arma_extra_debug_sigprint();
1013 
1014   #if defined(ARMA_USE_OPENMP)
1015     const uword n_threads_avail = (omp_in_parallel()) ? uword(1) : uword(omp_get_max_threads());
1016     const uword n_threads       = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
1017   #else
1018     static constexpr uword n_threads = 1;
1019   #endif
1020 
1021   // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
1022 
1023   umat boundaries(2, n_threads, arma_nozeros_indicator());
1024 
1025   if(N > 0)
1026     {
1027     const uword chunk_size = N / n_threads;
1028 
1029     uword count = 0;
1030 
1031     for(uword t=0; t<n_threads; t++)
1032       {
1033       boundaries.at(0,t) = count;
1034 
1035       count += chunk_size;
1036 
1037       boundaries.at(1,t) = count-1;
1038       }
1039 
1040     boundaries.at(1,n_threads-1) = N - 1;
1041     }
1042   else
1043     {
1044     boundaries.zeros();
1045     }
1046 
1047   // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
1048 
1049   return boundaries;
1050   }
1051 
1052 
1053 
1054 template<typename eT>
1055 arma_hot
1056 inline
1057 eT
internal_scalar_log_p(const eT * x) const1058 gmm_diag<eT>::internal_scalar_log_p(const eT* x) const
1059   {
1060   arma_extra_debug_sigprint();
1061 
1062   const eT* log_hefts_mem = log_hefts.mem;
1063 
1064   const uword N_gaus = means.n_cols;
1065 
1066   if(N_gaus > 0)
1067     {
1068     eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
1069 
1070     for(uword g=1; g < N_gaus; ++g)
1071       {
1072       const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g];
1073 
1074       log_sum = log_add_exp(log_sum, tmp);
1075       }
1076 
1077     return log_sum;
1078     }
1079   else
1080     {
1081     return -Datum<eT>::inf;
1082     }
1083   }
1084 
1085 
1086 
1087 template<typename eT>
1088 arma_hot
1089 inline
1090 eT
internal_scalar_log_p(const eT * x,const uword g) const1091 gmm_diag<eT>::internal_scalar_log_p(const eT* x, const uword g) const
1092   {
1093   arma_extra_debug_sigprint();
1094 
1095   const eT*     mean =     means.colptr(g);
1096   const eT* inv_dcov = inv_dcovs.colptr(g);
1097 
1098   const uword N_dims = means.n_rows;
1099 
1100   eT val_i = eT(0);
1101   eT val_j = eT(0);
1102 
1103   uword i,j;
1104 
1105   for(i=0, j=1; j<N_dims; i+=2, j+=2)
1106     {
1107     eT tmp_i = x[i];
1108     eT tmp_j = x[j];
1109 
1110     tmp_i -= mean[i];
1111     tmp_j -= mean[j];
1112 
1113     val_i += (tmp_i*tmp_i) * inv_dcov[i];
1114     val_j += (tmp_j*tmp_j) * inv_dcov[j];
1115     }
1116 
1117   if(i < N_dims)
1118     {
1119     const eT tmp = x[i] - mean[i];
1120 
1121     val_i += (tmp*tmp) * inv_dcov[i];
1122     }
1123 
1124   return eT(-0.5)*(val_i + val_j) + log_det_etc.mem[g];
1125   }
1126 
1127 
1128 
1129 template<typename eT>
1130 inline
1131 Row<eT>
internal_vec_log_p(const Mat<eT> & X) const1132 gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X) const
1133   {
1134   arma_extra_debug_sigprint();
1135 
1136   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
1137 
1138   const uword N = X.n_cols;
1139 
1140   Row<eT> out(N, arma_nozeros_indicator());
1141 
1142   if(N > 0)
1143     {
1144     #if defined(ARMA_USE_OPENMP)
1145       {
1146       const umat boundaries = internal_gen_boundaries(N);
1147 
1148       const uword n_threads = boundaries.n_cols;
1149 
1150       #pragma omp parallel for schedule(static)
1151       for(uword t=0; t < n_threads; ++t)
1152         {
1153         const uword start_index = boundaries.at(0,t);
1154         const uword   end_index = boundaries.at(1,t);
1155 
1156         eT* out_mem = out.memptr();
1157 
1158         for(uword i=start_index; i <= end_index; ++i)
1159           {
1160           out_mem[i] = internal_scalar_log_p( X.colptr(i) );
1161           }
1162         }
1163       }
1164     #else
1165       {
1166       eT* out_mem = out.memptr();
1167 
1168       for(uword i=0; i < N; ++i)
1169         {
1170         out_mem[i] = internal_scalar_log_p( X.colptr(i) );
1171         }
1172       }
1173     #endif
1174     }
1175 
1176   return out;
1177   }
1178 
1179 
1180 
1181 template<typename eT>
1182 inline
1183 Row<eT>
internal_vec_log_p(const Mat<eT> & X,const uword gaus_id) const1184 gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
1185   {
1186   arma_extra_debug_sigprint();
1187 
1188   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
1189   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
1190 
1191   const uword N = X.n_cols;
1192 
1193   Row<eT> out(N, arma_nozeros_indicator());
1194 
1195   if(N > 0)
1196     {
1197     #if defined(ARMA_USE_OPENMP)
1198       {
1199       const umat boundaries = internal_gen_boundaries(N);
1200 
1201       const uword n_threads = boundaries.n_cols;
1202 
1203       #pragma omp parallel for schedule(static)
1204       for(uword t=0; t < n_threads; ++t)
1205         {
1206         const uword start_index = boundaries.at(0,t);
1207         const uword   end_index = boundaries.at(1,t);
1208 
1209         eT* out_mem = out.memptr();
1210 
1211         for(uword i=start_index; i <= end_index; ++i)
1212           {
1213           out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
1214           }
1215         }
1216       }
1217     #else
1218       {
1219       eT* out_mem = out.memptr();
1220 
1221       for(uword i=0; i < N; ++i)
1222         {
1223         out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
1224         }
1225       }
1226     #endif
1227     }
1228 
1229   return out;
1230   }
1231 
1232 
1233 
1234 template<typename eT>
1235 inline
1236 eT
internal_sum_log_p(const Mat<eT> & X) const1237 gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X) const
1238   {
1239   arma_extra_debug_sigprint();
1240 
1241   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" );
1242 
1243   const uword N = X.n_cols;
1244 
1245   if(N == 0)  { return (-Datum<eT>::inf); }
1246 
1247 
1248   #if defined(ARMA_USE_OPENMP)
1249     {
1250     const umat boundaries = internal_gen_boundaries(N);
1251 
1252     const uword n_threads = boundaries.n_cols;
1253 
1254     Col<eT> t_accs(n_threads, arma_zeros_indicator());
1255 
1256     #pragma omp parallel for schedule(static)
1257     for(uword t=0; t < n_threads; ++t)
1258       {
1259       const uword start_index = boundaries.at(0,t);
1260       const uword   end_index = boundaries.at(1,t);
1261 
1262       eT t_acc = eT(0);
1263 
1264       for(uword i=start_index; i <= end_index; ++i)
1265         {
1266         t_acc += internal_scalar_log_p( X.colptr(i) );
1267         }
1268 
1269       t_accs[t] = t_acc;
1270       }
1271 
1272     return eT(accu(t_accs));
1273     }
1274   #else
1275     {
1276     eT acc = eT(0);
1277 
1278     for(uword i=0; i<N; ++i)
1279       {
1280       acc += internal_scalar_log_p( X.colptr(i) );
1281       }
1282 
1283     return acc;
1284     }
1285   #endif
1286   }
1287 
1288 
1289 
1290 template<typename eT>
1291 inline
1292 eT
internal_sum_log_p(const Mat<eT> & X,const uword gaus_id) const1293 gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
1294   {
1295   arma_extra_debug_sigprint();
1296 
1297   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions"            );
1298   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_diag::sum_log_p(): specified gaussian is out of range" );
1299 
1300   const uword N = X.n_cols;
1301 
1302   if(N == 0)  { return (-Datum<eT>::inf); }
1303 
1304 
1305   #if defined(ARMA_USE_OPENMP)
1306     {
1307     const umat boundaries = internal_gen_boundaries(N);
1308 
1309     const uword n_threads = boundaries.n_cols;
1310 
1311     Col<eT> t_accs(n_threads, arma_zeros_indicator());
1312 
1313     #pragma omp parallel for schedule(static)
1314     for(uword t=0; t < n_threads; ++t)
1315       {
1316       const uword start_index = boundaries.at(0,t);
1317       const uword   end_index = boundaries.at(1,t);
1318 
1319       eT t_acc = eT(0);
1320 
1321       for(uword i=start_index; i <= end_index; ++i)
1322         {
1323         t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
1324         }
1325 
1326       t_accs[t] = t_acc;
1327       }
1328 
1329     return eT(accu(t_accs));
1330     }
1331   #else
1332     {
1333     eT acc = eT(0);
1334 
1335     for(uword i=0; i<N; ++i)
1336       {
1337       acc += internal_scalar_log_p( X.colptr(i), gaus_id );
1338       }
1339 
1340     return acc;
1341     }
1342   #endif
1343   }
1344 
1345 
1346 
1347 template<typename eT>
1348 inline
1349 eT
internal_avg_log_p(const Mat<eT> & X) const1350 gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X) const
1351   {
1352   arma_extra_debug_sigprint();
1353 
1354   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" );
1355 
1356   const uword N = X.n_cols;
1357 
1358   if(N == 0)  { return (-Datum<eT>::inf); }
1359 
1360 
1361   #if defined(ARMA_USE_OPENMP)
1362     {
1363     const umat boundaries = internal_gen_boundaries(N);
1364 
1365     const uword n_threads = boundaries.n_cols;
1366 
1367     field< running_mean_scalar<eT> > t_running_means(n_threads);
1368 
1369 
1370     #pragma omp parallel for schedule(static)
1371     for(uword t=0; t < n_threads; ++t)
1372       {
1373       const uword start_index = boundaries.at(0,t);
1374       const uword   end_index = boundaries.at(1,t);
1375 
1376       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1377 
1378       for(uword i=start_index; i <= end_index; ++i)
1379         {
1380         current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
1381         }
1382       }
1383 
1384 
1385     eT avg = eT(0);
1386 
1387     for(uword t=0; t < n_threads; ++t)
1388       {
1389       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1390 
1391       const eT w = eT(current_running_mean.count()) / eT(N);
1392 
1393       avg += w * current_running_mean.mean();
1394       }
1395 
1396     return avg;
1397     }
1398   #else
1399     {
1400     running_mean_scalar<eT> running_mean;
1401 
1402     for(uword i=0; i<N; ++i)
1403       {
1404       running_mean( internal_scalar_log_p( X.colptr(i) ) );
1405       }
1406 
1407     return running_mean.mean();
1408     }
1409   #endif
1410   }
1411 
1412 
1413 
1414 template<typename eT>
1415 inline
1416 eT
internal_avg_log_p(const Mat<eT> & X,const uword gaus_id) const1417 gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
1418   {
1419   arma_extra_debug_sigprint();
1420 
1421   arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions"            );
1422   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_diag::avg_log_p(): specified gaussian is out of range" );
1423 
1424   const uword N = X.n_cols;
1425 
1426   if(N == 0)  { return (-Datum<eT>::inf); }
1427 
1428 
1429   #if defined(ARMA_USE_OPENMP)
1430     {
1431     const umat boundaries = internal_gen_boundaries(N);
1432 
1433     const uword n_threads = boundaries.n_cols;
1434 
1435     field< running_mean_scalar<eT> > t_running_means(n_threads);
1436 
1437 
1438     #pragma omp parallel for schedule(static)
1439     for(uword t=0; t < n_threads; ++t)
1440       {
1441       const uword start_index = boundaries.at(0,t);
1442       const uword   end_index = boundaries.at(1,t);
1443 
1444       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1445 
1446       for(uword i=start_index; i <= end_index; ++i)
1447         {
1448         current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
1449         }
1450       }
1451 
1452 
1453     eT avg = eT(0);
1454 
1455     for(uword t=0; t < n_threads; ++t)
1456       {
1457       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1458 
1459       const eT w = eT(current_running_mean.count()) / eT(N);
1460 
1461       avg += w * current_running_mean.mean();
1462       }
1463 
1464     return avg;
1465     }
1466   #else
1467     {
1468     running_mean_scalar<eT> running_mean;
1469 
1470     for(uword i=0; i<N; ++i)
1471       {
1472       running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
1473       }
1474 
1475     return running_mean.mean();
1476     }
1477   #endif
1478   }
1479 
1480 
1481 
1482 template<typename eT>
1483 inline
1484 uword
internal_scalar_assign(const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1485 gmm_diag<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1486   {
1487   arma_extra_debug_sigprint();
1488 
1489   const uword N_dims = means.n_rows;
1490   const uword N_gaus = means.n_cols;
1491 
1492   arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
1493   arma_debug_check( (N_gaus == 0),        "gmm_diag::assign(): model has no means"      );
1494 
1495   const eT* X_mem = X.colptr(0);
1496 
1497   if(dist_mode == eucl_dist)
1498     {
1499     eT    best_dist = Datum<eT>::inf;
1500     uword best_g    = 0;
1501 
1502     for(uword g=0; g < N_gaus; ++g)
1503       {
1504       const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
1505 
1506       if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
1507       }
1508 
1509     return best_g;
1510     }
1511   else
1512   if(dist_mode == prob_dist)
1513     {
1514     const eT* log_hefts_mem = log_hefts.memptr();
1515 
1516     eT    best_p = -Datum<eT>::inf;
1517     uword best_g = 0;
1518 
1519     for(uword g=0; g < N_gaus; ++g)
1520       {
1521       const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
1522 
1523       if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
1524       }
1525 
1526     return best_g;
1527     }
1528   else
1529     {
1530     arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
1531     }
1532 
1533   return uword(0);
1534   }
1535 
1536 
1537 
1538 template<typename eT>
1539 inline
1540 void
internal_vec_assign(urowvec & out,const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1541 gmm_diag<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1542   {
1543   arma_extra_debug_sigprint();
1544 
1545   const uword N_dims = means.n_rows;
1546   const uword N_gaus = means.n_cols;
1547 
1548   arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
1549 
1550   const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
1551 
1552   out.set_size(1,X_n_cols);
1553 
1554   uword* out_mem = out.memptr();
1555 
1556   if(dist_mode == eucl_dist)
1557     {
1558     #if defined(ARMA_USE_OPENMP)
1559       {
1560       #pragma omp parallel for schedule(static)
1561       for(uword i=0; i<X_n_cols; ++i)
1562         {
1563         const eT* X_colptr = X.colptr(i);
1564 
1565         eT    best_dist = Datum<eT>::inf;
1566         uword best_g    = 0;
1567 
1568         for(uword g=0; g<N_gaus; ++g)
1569           {
1570           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1571 
1572           if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
1573           }
1574 
1575         out_mem[i] = best_g;
1576         }
1577       }
1578     #else
1579       {
1580       for(uword i=0; i<X_n_cols; ++i)
1581         {
1582         const eT* X_colptr = X.colptr(i);
1583 
1584         eT    best_dist = Datum<eT>::inf;
1585         uword best_g    = 0;
1586 
1587         for(uword g=0; g<N_gaus; ++g)
1588           {
1589           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1590 
1591           if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
1592           }
1593 
1594         out_mem[i] = best_g;
1595         }
1596       }
1597     #endif
1598     }
1599   else
1600   if(dist_mode == prob_dist)
1601     {
1602     #if defined(ARMA_USE_OPENMP)
1603       {
1604       const eT* log_hefts_mem = log_hefts.memptr();
1605 
1606       #pragma omp parallel for schedule(static)
1607       for(uword i=0; i<X_n_cols; ++i)
1608         {
1609         const eT* X_colptr = X.colptr(i);
1610 
1611         eT    best_p = -Datum<eT>::inf;
1612         uword best_g = 0;
1613 
1614         for(uword g=0; g<N_gaus; ++g)
1615           {
1616           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1617 
1618           if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
1619           }
1620 
1621         out_mem[i] = best_g;
1622         }
1623       }
1624     #else
1625       {
1626       const eT* log_hefts_mem = log_hefts.memptr();
1627 
1628       for(uword i=0; i<X_n_cols; ++i)
1629         {
1630         const eT* X_colptr = X.colptr(i);
1631 
1632         eT    best_p = -Datum<eT>::inf;
1633         uword best_g = 0;
1634 
1635         for(uword g=0; g<N_gaus; ++g)
1636           {
1637           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1638 
1639           if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
1640           }
1641 
1642         out_mem[i] = best_g;
1643         }
1644       }
1645     #endif
1646     }
1647   else
1648     {
1649     arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
1650     }
1651   }
1652 
1653 
1654 
1655 
1656 template<typename eT>
1657 inline
1658 void
internal_raw_hist(urowvec & hist,const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1659 gmm_diag<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1660   {
1661   arma_extra_debug_sigprint();
1662 
1663   const uword N_dims = means.n_rows;
1664   const uword N_gaus = means.n_cols;
1665 
1666   const uword X_n_cols = X.n_cols;
1667 
1668   hist.zeros(N_gaus);
1669 
1670   if(N_gaus == 0)  { return; }
1671 
1672   #if defined(ARMA_USE_OPENMP)
1673     {
1674     const umat boundaries = internal_gen_boundaries(X_n_cols);
1675 
1676     const uword n_threads = boundaries.n_cols;
1677 
1678     field<urowvec> thread_hist(n_threads);
1679 
1680     for(uword t=0; t < n_threads; ++t)  { thread_hist(t).zeros(N_gaus); }
1681 
1682 
1683     if(dist_mode == eucl_dist)
1684       {
1685       #pragma omp parallel for schedule(static)
1686       for(uword t=0; t < n_threads; ++t)
1687         {
1688         uword* thread_hist_mem = thread_hist(t).memptr();
1689 
1690         const uword start_index = boundaries.at(0,t);
1691         const uword   end_index = boundaries.at(1,t);
1692 
1693         for(uword i=start_index; i <= end_index; ++i)
1694           {
1695           const eT* X_colptr = X.colptr(i);
1696 
1697           eT    best_dist = Datum<eT>::inf;
1698           uword best_g    = 0;
1699 
1700           for(uword g=0; g < N_gaus; ++g)
1701             {
1702             const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1703 
1704             if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
1705             }
1706 
1707           thread_hist_mem[best_g]++;
1708           }
1709         }
1710       }
1711     else
1712     if(dist_mode == prob_dist)
1713       {
1714       const eT* log_hefts_mem = log_hefts.memptr();
1715 
1716       #pragma omp parallel for schedule(static)
1717       for(uword t=0; t < n_threads; ++t)
1718         {
1719         uword* thread_hist_mem = thread_hist(t).memptr();
1720 
1721         const uword start_index = boundaries.at(0,t);
1722         const uword   end_index = boundaries.at(1,t);
1723 
1724         for(uword i=start_index; i <= end_index; ++i)
1725           {
1726           const eT* X_colptr = X.colptr(i);
1727 
1728           eT    best_p = -Datum<eT>::inf;
1729           uword best_g = 0;
1730 
1731           for(uword g=0; g < N_gaus; ++g)
1732             {
1733             const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1734 
1735             if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
1736             }
1737 
1738           thread_hist_mem[best_g]++;
1739           }
1740         }
1741       }
1742 
1743     // reduction
1744     hist = thread_hist(0);
1745 
1746     for(uword t=1; t < n_threads; ++t)
1747       {
1748       hist += thread_hist(t);
1749       }
1750     }
1751   #else
1752     {
1753     uword* hist_mem = hist.memptr();
1754 
1755     if(dist_mode == eucl_dist)
1756       {
1757       for(uword i=0; i<X_n_cols; ++i)
1758         {
1759         const eT* X_colptr = X.colptr(i);
1760 
1761         eT    best_dist = Datum<eT>::inf;
1762         uword best_g    = 0;
1763 
1764         for(uword g=0; g < N_gaus; ++g)
1765           {
1766           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1767 
1768           if(tmp_dist <= best_dist)  { best_dist = tmp_dist;  best_g = g; }
1769           }
1770 
1771         hist_mem[best_g]++;
1772         }
1773       }
1774     else
1775     if(dist_mode == prob_dist)
1776       {
1777       const eT* log_hefts_mem = log_hefts.memptr();
1778 
1779       for(uword i=0; i<X_n_cols; ++i)
1780         {
1781         const eT* X_colptr = X.colptr(i);
1782 
1783         eT    best_p = -Datum<eT>::inf;
1784         uword best_g = 0;
1785 
1786         for(uword g=0; g < N_gaus; ++g)
1787           {
1788           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1789 
1790           if(tmp_p >= best_p)  { best_p = tmp_p;  best_g = g; }
1791           }
1792 
1793         hist_mem[best_g]++;
1794         }
1795       }
1796     }
1797   #endif
1798   }
1799 
1800 
1801 
1802 template<typename eT>
1803 template<uword dist_id>
1804 inline
1805 void
generate_initial_means(const Mat<eT> & X,const gmm_seed_mode & seed_mode)1806 gmm_diag<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
1807   {
1808   arma_extra_debug_sigprint();
1809 
1810   const uword N_dims = means.n_rows;
1811   const uword N_gaus = means.n_cols;
1812 
1813   if( (seed_mode == static_subset) || (seed_mode == random_subset) )
1814     {
1815     uvec initial_indices;
1816 
1817          if(seed_mode == static_subset)  { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
1818     else if(seed_mode == random_subset)  { initial_indices = randperm<uvec>(X.n_cols, N_gaus);      }
1819 
1820     // initial_indices.print("initial_indices:");
1821 
1822     access::rw(means) = X.cols(initial_indices);
1823     }
1824   else
1825   if( (seed_mode == static_spread) || (seed_mode == random_spread) )
1826     {
1827     // going through all of the samples can be extremely time consuming;
1828     // instead, if there are enough samples, randomly choose samples with probability 0.1
1829 
1830     const bool  use_sampling = ((X.n_cols/uword(100)) > N_gaus);
1831     const uword step         = (use_sampling) ? uword(10) : uword(1);
1832 
1833     uword start_index = 0;
1834 
1835          if(seed_mode == static_spread)  { start_index = X.n_cols / 2;                                         }
1836     else if(seed_mode == random_spread)  { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
1837 
1838     access::rw(means).col(0) = X.unsafe_col(start_index);
1839 
1840     const eT* mah_aux_mem = mah_aux.memptr();
1841 
1842     running_stat<double> rs;
1843 
1844     for(uword g=1; g < N_gaus; ++g)
1845       {
1846       eT    max_dist = eT(0);
1847       uword best_i   = uword(0);
1848       uword start_i  = uword(0);
1849 
1850       if(use_sampling)
1851         {
1852         uword start_i_proposed = uword(0);
1853 
1854         if(seed_mode == static_spread)  { start_i_proposed = g % uword(10);                               }
1855         if(seed_mode == random_spread)  { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
1856 
1857         if(start_i_proposed < X.n_cols)  { start_i = start_i_proposed; }
1858         }
1859 
1860 
1861       for(uword i=start_i; i < X.n_cols; i += step)
1862         {
1863         rs.reset();
1864 
1865         const eT* X_colptr = X.colptr(i);
1866 
1867         bool ignore_i = false;
1868 
1869         // find the average distance between sample i and the means so far
1870         for(uword h = 0; h < g; ++h)
1871           {
1872           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
1873 
1874           // ignore sample already selected as a mean
1875           if(dist == eT(0))  { ignore_i = true; break; }
1876           else               { rs(dist);               }
1877           }
1878 
1879         if( (rs.mean() >= max_dist) && (ignore_i == false))
1880           {
1881           max_dist = eT(rs.mean()); best_i = i;
1882           }
1883         }
1884 
1885       // set the mean to the sample that is the furthest away from the means so far
1886       access::rw(means).col(g) = X.unsafe_col(best_i);
1887       }
1888     }
1889 
1890   // get_cout_stream() << "generate_initial_means():" << '\n';
1891   // means.print();
1892   }
1893 
1894 
1895 
1896 template<typename eT>
1897 template<uword dist_id>
1898 inline
1899 void
generate_initial_params(const Mat<eT> & X,const eT var_floor)1900 gmm_diag<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
1901   {
1902   arma_extra_debug_sigprint();
1903 
1904   const uword N_dims = means.n_rows;
1905   const uword N_gaus = means.n_cols;
1906 
1907   const eT* mah_aux_mem = mah_aux.memptr();
1908 
1909   const uword X_n_cols = X.n_cols;
1910 
1911   if(X_n_cols == 0)  { return; }
1912 
1913   // as the covariances are calculated via accumulators,
1914   // the means also need to be calculated via accumulators to ensure numerical consistency
1915 
1916   Mat<eT> acc_means(N_dims, N_gaus, arma_zeros_indicator());
1917   Mat<eT> acc_dcovs(N_dims, N_gaus, arma_zeros_indicator());
1918 
1919   Row<uword> acc_hefts(N_gaus, arma_zeros_indicator());
1920 
1921   uword* acc_hefts_mem = acc_hefts.memptr();
1922 
1923   #if defined(ARMA_USE_OPENMP)
1924     {
1925     const umat boundaries = internal_gen_boundaries(X_n_cols);
1926 
1927     const uword n_threads = boundaries.n_cols;
1928 
1929     field< Mat<eT>    > t_acc_means(n_threads);
1930     field< Mat<eT>    > t_acc_dcovs(n_threads);
1931     field< Row<uword> > t_acc_hefts(n_threads);
1932 
1933     for(uword t=0; t < n_threads; ++t)
1934       {
1935       t_acc_means(t).zeros(N_dims, N_gaus);
1936       t_acc_dcovs(t).zeros(N_dims, N_gaus);
1937       t_acc_hefts(t).zeros(N_gaus);
1938       }
1939 
1940     #pragma omp parallel for schedule(static)
1941     for(uword t=0; t < n_threads; ++t)
1942       {
1943       uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
1944 
1945       const uword start_index = boundaries.at(0,t);
1946       const uword   end_index = boundaries.at(1,t);
1947 
1948       for(uword i=start_index; i <= end_index; ++i)
1949         {
1950         const eT* X_colptr = X.colptr(i);
1951 
1952         eT     min_dist = Datum<eT>::inf;
1953         uword  best_g   = 0;
1954 
1955         for(uword g=0; g<N_gaus; ++g)
1956           {
1957           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
1958 
1959           if(dist < min_dist)  { min_dist = dist;  best_g = g; }
1960           }
1961 
1962         eT* t_acc_mean = t_acc_means(t).colptr(best_g);
1963         eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
1964 
1965         for(uword d=0; d<N_dims; ++d)
1966           {
1967           const eT x_d = X_colptr[d];
1968 
1969           t_acc_mean[d] += x_d;
1970           t_acc_dcov[d] += x_d*x_d;
1971           }
1972 
1973         t_acc_hefts_mem[best_g]++;
1974         }
1975       }
1976 
1977     // reduction
1978     acc_means = t_acc_means(0);
1979     acc_dcovs = t_acc_dcovs(0);
1980     acc_hefts = t_acc_hefts(0);
1981 
1982     for(uword t=1; t < n_threads; ++t)
1983       {
1984       acc_means += t_acc_means(t);
1985       acc_dcovs += t_acc_dcovs(t);
1986       acc_hefts += t_acc_hefts(t);
1987       }
1988     }
1989   #else
1990     {
1991     for(uword i=0; i<X_n_cols; ++i)
1992       {
1993       const eT* X_colptr = X.colptr(i);
1994 
1995       eT     min_dist = Datum<eT>::inf;
1996       uword  best_g   = 0;
1997 
1998       for(uword g=0; g<N_gaus; ++g)
1999         {
2000         const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
2001 
2002         if(dist < min_dist)  { min_dist = dist;  best_g = g; }
2003         }
2004 
2005       eT* acc_mean = acc_means.colptr(best_g);
2006       eT* acc_dcov = acc_dcovs.colptr(best_g);
2007 
2008       for(uword d=0; d<N_dims; ++d)
2009         {
2010         const eT x_d = X_colptr[d];
2011 
2012         acc_mean[d] += x_d;
2013         acc_dcov[d] += x_d*x_d;
2014         }
2015 
2016       acc_hefts_mem[best_g]++;
2017       }
2018     }
2019   #endif
2020 
2021   eT* hefts_mem = access::rw(hefts).memptr();
2022 
2023   for(uword g=0; g<N_gaus; ++g)
2024     {
2025     const eT*   acc_mean = acc_means.colptr(g);
2026     const eT*   acc_dcov = acc_dcovs.colptr(g);
2027     const uword acc_heft = acc_hefts_mem[g];
2028 
2029     eT* mean = access::rw(means).colptr(g);
2030     eT* dcov = access::rw(dcovs).colptr(g);
2031 
2032     for(uword d=0; d<N_dims; ++d)
2033       {
2034       const eT tmp = acc_mean[d] / eT(acc_heft);
2035 
2036       mean[d] = (acc_heft >= 1) ? tmp : eT(0);
2037       dcov[d] = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
2038       }
2039 
2040     hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
2041     }
2042 
2043   em_fix_params(var_floor);
2044   }
2045 
2046 
2047 
2048 //! multi-threaded implementation of k-means, inspired by MapReduce
2049 template<typename eT>
2050 template<uword dist_id>
2051 inline
2052 bool
km_iterate(const Mat<eT> & X,const uword max_iter,const bool verbose,const char * signature)2053 gmm_diag<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose, const char* signature)
2054   {
2055   arma_extra_debug_sigprint();
2056 
2057   if(verbose)
2058     {
2059     get_cout_stream().unsetf(ios::showbase);
2060     get_cout_stream().unsetf(ios::uppercase);
2061     get_cout_stream().unsetf(ios::showpos);
2062     get_cout_stream().unsetf(ios::scientific);
2063 
2064     get_cout_stream().setf(ios::right);
2065     get_cout_stream().setf(ios::fixed);
2066     }
2067 
2068   const uword X_n_cols = X.n_cols;
2069 
2070   if(X_n_cols == 0)  { return true; }
2071 
2072   const uword N_dims = means.n_rows;
2073   const uword N_gaus = means.n_cols;
2074 
2075   const eT* mah_aux_mem = mah_aux.memptr();
2076 
2077   Mat<eT>    acc_means(N_dims, N_gaus, arma_zeros_indicator());
2078   Row<uword> acc_hefts(        N_gaus, arma_zeros_indicator());
2079   Row<uword> last_indx(        N_gaus, arma_zeros_indicator());
2080 
2081   Mat<eT> new_means = means;
2082   Mat<eT> old_means = means;
2083 
2084   running_mean_scalar<eT> rs_delta;
2085 
2086   #if defined(ARMA_USE_OPENMP)
2087     const umat boundaries = internal_gen_boundaries(X_n_cols);
2088     const uword n_threads = boundaries.n_cols;
2089 
2090     field< Mat<eT>    > t_acc_means(n_threads);
2091     field< Row<uword> > t_acc_hefts(n_threads);
2092     field< Row<uword> > t_last_indx(n_threads);
2093   #else
2094     const uword n_threads = 1;
2095   #endif
2096 
2097   if(verbose)  { get_cout_stream() << signature << ": n_threads: " << n_threads  << '\n'; get_cout_stream().flush(); }
2098 
2099   for(uword iter=1; iter <= max_iter; ++iter)
2100     {
2101     #if defined(ARMA_USE_OPENMP)
2102       {
2103       for(uword t=0; t < n_threads; ++t)
2104         {
2105         t_acc_means(t).zeros(N_dims, N_gaus);
2106         t_acc_hefts(t).zeros(N_gaus);
2107         t_last_indx(t).zeros(N_gaus);
2108         }
2109 
2110       #pragma omp parallel for schedule(static)
2111       for(uword t=0; t < n_threads; ++t)
2112         {
2113         Mat<eT>& t_acc_means_t   = t_acc_means(t);
2114         uword*   t_acc_hefts_mem = t_acc_hefts(t).memptr();
2115         uword*   t_last_indx_mem = t_last_indx(t).memptr();
2116 
2117         const uword start_index = boundaries.at(0,t);
2118         const uword   end_index = boundaries.at(1,t);
2119 
2120         for(uword i=start_index; i <= end_index; ++i)
2121           {
2122           const eT* X_colptr = X.colptr(i);
2123 
2124           eT     min_dist = Datum<eT>::inf;
2125           uword  best_g   = 0;
2126 
2127           for(uword g=0; g<N_gaus; ++g)
2128             {
2129             const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
2130 
2131             if(dist < min_dist)  { min_dist = dist;  best_g = g; }
2132             }
2133 
2134           eT* t_acc_mean = t_acc_means_t.colptr(best_g);
2135 
2136           for(uword d=0; d<N_dims; ++d)  { t_acc_mean[d] += X_colptr[d]; }
2137 
2138           t_acc_hefts_mem[best_g]++;
2139           t_last_indx_mem[best_g] = i;
2140           }
2141         }
2142 
2143       // reduction
2144 
2145       acc_means = t_acc_means(0);
2146       acc_hefts = t_acc_hefts(0);
2147 
2148       for(uword t=1; t < n_threads; ++t)
2149         {
2150         acc_means += t_acc_means(t);
2151         acc_hefts += t_acc_hefts(t);
2152         }
2153 
2154       for(uword g=0; g < N_gaus;    ++g)
2155       for(uword t=0; t < n_threads; ++t)
2156         {
2157         if( t_acc_hefts(t)(g) >= 1 )  { last_indx(g) = t_last_indx(t)(g); }
2158         }
2159       }
2160     #else
2161       {
2162       uword* acc_hefts_mem = acc_hefts.memptr();
2163       uword* last_indx_mem = last_indx.memptr();
2164 
2165       for(uword i=0; i < X_n_cols; ++i)
2166         {
2167         const eT* X_colptr = X.colptr(i);
2168 
2169         eT     min_dist = Datum<eT>::inf;
2170         uword  best_g   = 0;
2171 
2172         for(uword g=0; g<N_gaus; ++g)
2173           {
2174           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
2175 
2176           if(dist < min_dist)  { min_dist = dist;  best_g = g; }
2177           }
2178 
2179         eT* acc_mean = acc_means.colptr(best_g);
2180 
2181         for(uword d=0; d<N_dims; ++d)  { acc_mean[d] += X_colptr[d]; }
2182 
2183         acc_hefts_mem[best_g]++;
2184         last_indx_mem[best_g] = i;
2185         }
2186       }
2187     #endif
2188 
2189     // generate new means
2190 
2191     uword* acc_hefts_mem = acc_hefts.memptr();
2192 
2193     for(uword g=0; g < N_gaus; ++g)
2194       {
2195       const eT*   acc_mean = acc_means.colptr(g);
2196       const uword acc_heft = acc_hefts_mem[g];
2197 
2198       eT* new_mean = access::rw(new_means).colptr(g);
2199 
2200       for(uword d=0; d<N_dims; ++d)
2201         {
2202         new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
2203         }
2204       }
2205 
2206 
2207     // heuristics to resurrect dead means
2208 
2209     const uvec dead_gs = find(acc_hefts == uword(0));
2210 
2211     if(dead_gs.n_elem > 0)
2212       {
2213       if(verbose)  { get_cout_stream() << signature << ": recovering from dead means\n"; get_cout_stream().flush(); }
2214 
2215       uword* last_indx_mem = last_indx.memptr();
2216 
2217       const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
2218 
2219       if(live_gs.n_elem == 0)  { return false; }
2220 
2221       uword live_gs_count  = 0;
2222 
2223       for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
2224         {
2225         const uword dead_g_id = dead_gs(dead_gs_count);
2226 
2227         uword proposed_i = 0;
2228 
2229         if(live_gs_count < live_gs.n_elem)
2230           {
2231           const uword live_g_id = live_gs(live_gs_count);  ++live_gs_count;
2232 
2233           if(live_g_id == dead_g_id)  { return false; }
2234 
2235           // recover by using a sample from a known good mean
2236           proposed_i = last_indx_mem[live_g_id];
2237           }
2238         else
2239           {
          // recover by using a randomly selected sample (last resort)
2241           proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
2242           }
2243 
2244         if(proposed_i >= X_n_cols)  { return false; }
2245 
2246         new_means.col(dead_g_id) = X.col(proposed_i);
2247         }
2248       }
2249 
2250     rs_delta.reset();
2251 
2252     for(uword g=0; g < N_gaus; ++g)
2253       {
2254       rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
2255       }
2256 
2257     if(verbose)
2258       {
2259       get_cout_stream() << signature << ": iteration: ";
2260       get_cout_stream().unsetf(ios::scientific);
2261       get_cout_stream().setf(ios::fixed);
2262       get_cout_stream().width(std::streamsize(4));
2263       get_cout_stream() << iter;
2264       get_cout_stream() << "   delta: ";
2265       get_cout_stream().unsetf(ios::fixed);
2266       //get_cout_stream().setf(ios::scientific);
2267       get_cout_stream() << rs_delta.mean() << '\n';
2268       get_cout_stream().flush();
2269       }
2270 
2271     arma::swap(old_means, new_means);
2272 
2273     if(rs_delta.mean() <= Datum<eT>::eps)  { break; }
2274     }
2275 
2276   access::rw(means) = old_means;
2277 
2278   if(means.is_finite() == false)  { return false; }
2279 
2280   return true;
2281   }
2282 
2283 
2284 
2285 //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce
2286 template<typename eT>
2287 inline
2288 bool
em_iterate(const Mat<eT> & X,const uword max_iter,const eT var_floor,const bool verbose)2289 gmm_diag<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
2290   {
2291   arma_extra_debug_sigprint();
2292 
2293   if(X.n_cols == 0)  { return true; }
2294 
2295   const uword N_dims = means.n_rows;
2296   const uword N_gaus = means.n_cols;
2297 
2298   if(verbose)
2299     {
2300     get_cout_stream().unsetf(ios::showbase);
2301     get_cout_stream().unsetf(ios::uppercase);
2302     get_cout_stream().unsetf(ios::showpos);
2303     get_cout_stream().unsetf(ios::scientific);
2304 
2305     get_cout_stream().setf(ios::right);
2306     get_cout_stream().setf(ios::fixed);
2307     }
2308 
2309   const umat boundaries = internal_gen_boundaries(X.n_cols);
2310 
2311   const uword n_threads = boundaries.n_cols;
2312 
2313   field< Mat<eT> > t_acc_means(n_threads);
2314   field< Mat<eT> > t_acc_dcovs(n_threads);
2315 
2316   field< Col<eT> > t_acc_norm_lhoods(n_threads);
2317   field< Col<eT> > t_gaus_log_lhoods(n_threads);
2318 
2319   Col<eT>          t_progress_log_lhood(n_threads, arma_nozeros_indicator());
2320 
2321   for(uword t=0; t<n_threads; t++)
2322     {
2323     t_acc_means[t].set_size(N_dims, N_gaus);
2324     t_acc_dcovs[t].set_size(N_dims, N_gaus);
2325 
2326     t_acc_norm_lhoods[t].set_size(N_gaus);
2327     t_gaus_log_lhoods[t].set_size(N_gaus);
2328     }
2329 
2330 
2331   if(verbose)
2332     {
2333     get_cout_stream() << "gmm_diag::learn(): EM: n_threads: " << n_threads  << '\n';
2334     }
2335 
2336   eT old_avg_log_p = -Datum<eT>::inf;
2337 
2338   for(uword iter=1; iter <= max_iter; ++iter)
2339     {
2340     init_constants();
2341 
2342     em_update_params(X, boundaries, t_acc_means, t_acc_dcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood);
2343 
2344     em_fix_params(var_floor);
2345 
2346     const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
2347 
2348     if(verbose)
2349       {
2350       get_cout_stream() << "gmm_diag::learn(): EM: iteration: ";
2351       get_cout_stream().unsetf(ios::scientific);
2352       get_cout_stream().setf(ios::fixed);
2353       get_cout_stream().width(std::streamsize(4));
2354       get_cout_stream() << iter;
2355       get_cout_stream() << "   avg_log_p: ";
2356       get_cout_stream().unsetf(ios::fixed);
2357       //get_cout_stream().setf(ios::scientific);
2358       get_cout_stream() << new_avg_log_p << '\n';
2359       get_cout_stream().flush();
2360       }
2361 
2362     if(arma_isfinite(new_avg_log_p) == false)  { return false; }
2363 
2364     if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps)  { break; }
2365 
2366 
2367     old_avg_log_p = new_avg_log_p;
2368     }
2369 
2370 
2371   if(any(vectorise(dcovs) <= eT(0)))  { return false; }
2372   if(means.is_finite() == false    )  { return false; }
2373   if(dcovs.is_finite() == false    )  { return false; }
2374   if(hefts.is_finite() == false    )  { return false; }
2375 
2376   return true;
2377   }
2378 
2379 
2380 
2381 
2382 template<typename eT>
2383 inline
2384 void
em_update_params(const Mat<eT> & X,const umat & boundaries,field<Mat<eT>> & t_acc_means,field<Mat<eT>> & t_acc_dcovs,field<Col<eT>> & t_acc_norm_lhoods,field<Col<eT>> & t_gaus_log_lhoods,Col<eT> & t_progress_log_lhood)2385 gmm_diag<eT>::em_update_params
2386   (
2387   const Mat<eT>&          X,
2388   const umat&             boundaries,
2389         field< Mat<eT> >& t_acc_means,
2390         field< Mat<eT> >& t_acc_dcovs,
2391         field< Col<eT> >& t_acc_norm_lhoods,
2392         field< Col<eT> >& t_gaus_log_lhoods,
2393         Col<eT>&          t_progress_log_lhood
2394   )
2395   {
2396   arma_extra_debug_sigprint();
2397 
2398   const uword n_threads = boundaries.n_cols;
2399 
2400 
2401   // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
2402 
2403   #if defined(ARMA_USE_OPENMP)
2404     {
2405     #pragma omp parallel for schedule(static)
2406     for(uword t=0; t<n_threads; t++)
2407       {
2408       Mat<eT>& acc_means          = t_acc_means[t];
2409       Mat<eT>& acc_dcovs          = t_acc_dcovs[t];
2410       Col<eT>& acc_norm_lhoods    = t_acc_norm_lhoods[t];
2411       Col<eT>& gaus_log_lhoods    = t_gaus_log_lhoods[t];
2412       eT&      progress_log_lhood = t_progress_log_lhood[t];
2413 
2414       em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_dcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
2415       }
2416     }
2417   #else
2418     {
2419     em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_dcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
2420     }
2421   #endif
2422 
2423   const uword N_dims = means.n_rows;
2424   const uword N_gaus = means.n_cols;
2425 
2426   Mat<eT>& final_acc_means = t_acc_means[0];
2427   Mat<eT>& final_acc_dcovs = t_acc_dcovs[0];
2428 
2429   Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
2430 
2431 
2432   // the "reduce" operation, which combines the partial accumulators produced by the separate threads
2433 
2434   for(uword t=1; t<n_threads; t++)
2435     {
2436     final_acc_means += t_acc_means[t];
2437     final_acc_dcovs += t_acc_dcovs[t];
2438 
2439     final_acc_norm_lhoods += t_acc_norm_lhoods[t];
2440     }
2441 
2442 
2443   eT* hefts_mem = access::rw(hefts).memptr();
2444 
2445 
2446   //// update each component without sanity checking
2447   //for(uword g=0; g < N_gaus; ++g)
2448   //  {
2449   //  const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
2450   //
2451   //  eT* mean_mem = access::rw(means).colptr(g);
2452   //  eT* dcov_mem = access::rw(dcovs).colptr(g);
2453   //
2454   //  eT* acc_mean_mem = final_acc_means.colptr(g);
2455   //  eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
2456   //
2457   //  hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
2458   //
2459   //  for(uword d=0; d < N_dims; ++d)
2460   //    {
2461   //    const eT tmp = acc_mean_mem[d] / acc_norm_lhood;
2462   //
2463   //    mean_mem[d] = tmp;
2464   //    dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp;
2465   //    }
2466   //  }
2467 
2468 
2469   // conditionally update each component;  if only a subset of the hefts was updated, em_fix_params() will sanitise them
2470   for(uword g=0; g < N_gaus; ++g)
2471     {
2472     const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
2473 
2474     if(arma_isfinite(acc_norm_lhood) == false)  { continue; }
2475 
2476     eT* acc_mean_mem = final_acc_means.colptr(g);
2477     eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
2478 
2479     bool ok = true;
2480 
2481     for(uword d=0; d < N_dims; ++d)
2482       {
2483       const eT tmp1 = acc_mean_mem[d] / acc_norm_lhood;
2484       const eT tmp2 = acc_dcov_mem[d] / acc_norm_lhood - tmp1*tmp1;
2485 
2486       acc_mean_mem[d] = tmp1;
2487       acc_dcov_mem[d] = tmp2;
2488 
2489       if(arma_isfinite(tmp2) == false)  { ok = false; }
2490       }
2491 
2492 
2493     if(ok)
2494       {
2495       hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
2496 
2497       eT* mean_mem = access::rw(means).colptr(g);
2498       eT* dcov_mem = access::rw(dcovs).colptr(g);
2499 
2500       for(uword d=0; d < N_dims; ++d)
2501         {
2502         mean_mem[d] = acc_mean_mem[d];
2503         dcov_mem[d] = acc_dcov_mem[d];
2504         }
2505       }
2506     }
2507   }
2508 
2509 
2510 
2511 template<typename eT>
2512 inline
2513 void
em_generate_acc(const Mat<eT> & X,const uword start_index,const uword end_index,Mat<eT> & acc_means,Mat<eT> & acc_dcovs,Col<eT> & acc_norm_lhoods,Col<eT> & gaus_log_lhoods,eT & progress_log_lhood) const2514 gmm_diag<eT>::em_generate_acc
2515   (
2516   const Mat<eT>& X,
2517   const uword    start_index,
2518   const uword      end_index,
2519         Mat<eT>& acc_means,
2520         Mat<eT>& acc_dcovs,
2521         Col<eT>& acc_norm_lhoods,
2522         Col<eT>& gaus_log_lhoods,
2523         eT&      progress_log_lhood
2524   )
2525   const
2526   {
2527   arma_extra_debug_sigprint();
2528 
2529   progress_log_lhood = eT(0);
2530 
2531   acc_means.zeros();
2532   acc_dcovs.zeros();
2533 
2534   acc_norm_lhoods.zeros();
2535   gaus_log_lhoods.zeros();
2536 
2537   const uword N_dims = means.n_rows;
2538   const uword N_gaus = means.n_cols;
2539 
2540   const eT* log_hefts_mem       = log_hefts.memptr();
2541         eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
2542 
2543 
2544   for(uword i=start_index; i <= end_index; i++)
2545     {
2546     const eT* x = X.colptr(i);
2547 
2548     for(uword g=0; g < N_gaus; ++g)
2549       {
2550       gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
2551       }
2552 
2553     eT log_lhood_sum = gaus_log_lhoods_mem[0];
2554 
2555     for(uword g=1; g < N_gaus; ++g)
2556       {
2557       log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
2558       }
2559 
2560     progress_log_lhood += log_lhood_sum;
2561 
2562     for(uword g=0; g < N_gaus; ++g)
2563       {
2564       const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
2565 
2566       acc_norm_lhoods[g] += norm_lhood;
2567 
2568       eT* acc_mean_mem = acc_means.colptr(g);
2569       eT* acc_dcov_mem = acc_dcovs.colptr(g);
2570 
2571       for(uword d=0; d < N_dims; ++d)
2572         {
2573         const eT x_d = x[d];
2574         const eT y_d = x_d * norm_lhood;
2575 
2576         acc_mean_mem[d] += y_d;
2577         acc_dcov_mem[d] += y_d * x_d;  // equivalent to x_d * x_d * norm_lhood
2578         }
2579       }
2580     }
2581 
2582   progress_log_lhood /= eT((end_index - start_index) + 1);
2583   }
2584 
2585 
2586 
2587 template<typename eT>
2588 inline
2589 void
em_fix_params(const eT var_floor)2590 gmm_diag<eT>::em_fix_params(const eT var_floor)
2591   {
2592   arma_extra_debug_sigprint();
2593 
2594   const uword N_dims = means.n_rows;
2595   const uword N_gaus = means.n_cols;
2596 
2597   const eT var_ceiling = std::numeric_limits<eT>::max();
2598 
2599   const uword dcovs_n_elem = dcovs.n_elem;
2600         eT*   dcovs_mem    = access::rw(dcovs).memptr();
2601 
2602   for(uword i=0; i < dcovs_n_elem; ++i)
2603     {
2604     eT& var_val = dcovs_mem[i];
2605 
2606          if(var_val < var_floor  )  { var_val = var_floor;   }
2607     else if(var_val > var_ceiling)  { var_val = var_ceiling; }
2608     else if(arma_isnan(var_val)  )  { var_val = eT(1);       }
2609     }
2610 
2611 
2612   eT* hefts_mem = access::rw(hefts).memptr();
2613 
2614   for(uword g1=0; g1 < N_gaus; ++g1)
2615     {
2616     if(hefts_mem[g1] > eT(0))
2617       {
2618       const eT* means_colptr_g1 = means.colptr(g1);
2619 
2620       for(uword g2=(g1+1); g2 < N_gaus; ++g2)
2621         {
2622         if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
2623           {
2624           const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
2625 
2626           if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
2627           }
2628         }
2629       }
2630     }
2631 
2632   const eT heft_floor   = std::numeric_limits<eT>::min();
2633   const eT heft_initial = eT(1) / eT(N_gaus);
2634 
2635   for(uword i=0; i < N_gaus; ++i)
2636     {
2637     eT& heft_val = hefts_mem[i];
2638 
2639          if(heft_val < heft_floor)  { heft_val = heft_floor;   }
2640     else if(heft_val > eT(1)     )  { heft_val = eT(1);        }
2641     else if(arma_isnan(heft_val) )  { heft_val = heft_initial; }
2642     }
2643 
2644   const eT heft_sum = accu(hefts);
2645 
2646   if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps)))  { access::rw(hefts) /= heft_sum; }
2647   }
2648 
2649 
2650 } // namespace gmm_priv
2651 
2652 
2653 //! @}
2654