1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17 
18 
19 //! \addtogroup gmm_full
20 //! @{
21 
22 
23 namespace gmm_priv
24 {
25 
26 
27 template<typename eT>
28 inline
~gmm_full()29 gmm_full<eT>::~gmm_full()
30   {
31   arma_extra_debug_sigprint_this(this);
32 
33   arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
34   }
35 
36 
37 
38 template<typename eT>
39 inline
gmm_full()40 gmm_full<eT>::gmm_full()
41   {
42   arma_extra_debug_sigprint_this(this);
43   }
44 
45 
46 
47 template<typename eT>
48 inline
gmm_full(const gmm_full<eT> & x)49 gmm_full<eT>::gmm_full(const gmm_full<eT>& x)
50   {
51   arma_extra_debug_sigprint_this(this);
52 
53   init(x);
54   }
55 
56 
57 
58 template<typename eT>
59 inline
60 gmm_full<eT>&
operator =(const gmm_full<eT> & x)61 gmm_full<eT>::operator=(const gmm_full<eT>& x)
62   {
63   arma_extra_debug_sigprint();
64 
65   init(x);
66 
67   return *this;
68   }
69 
70 
71 
72 template<typename eT>
73 inline
gmm_full(const gmm_diag<eT> & x)74 gmm_full<eT>::gmm_full(const gmm_diag<eT>& x)
75   {
76   arma_extra_debug_sigprint_this(this);
77 
78   init(x);
79   }
80 
81 
82 
83 template<typename eT>
84 inline
85 gmm_full<eT>&
operator =(const gmm_diag<eT> & x)86 gmm_full<eT>::operator=(const gmm_diag<eT>& x)
87   {
88   arma_extra_debug_sigprint();
89 
90   init(x);
91 
92   return *this;
93   }
94 
95 
96 
97 template<typename eT>
98 inline
gmm_full(const uword in_n_dims,const uword in_n_gaus)99 gmm_full<eT>::gmm_full(const uword in_n_dims, const uword in_n_gaus)
100   {
101   arma_extra_debug_sigprint_this(this);
102 
103   init(in_n_dims, in_n_gaus);
104   }
105 
106 
107 
108 template<typename eT>
109 inline
110 void
reset()111 gmm_full<eT>::reset()
112   {
113   arma_extra_debug_sigprint();
114 
115   init(0, 0);
116   }
117 
118 
119 
120 template<typename eT>
121 inline
122 void
reset(const uword in_n_dims,const uword in_n_gaus)123 gmm_full<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
124   {
125   arma_extra_debug_sigprint();
126 
127   init(in_n_dims, in_n_gaus);
128   }
129 
130 
131 
132 template<typename eT>
133 template<typename T1, typename T2, typename T3>
134 inline
135 void
set_params(const Base<eT,T1> & in_means_expr,const BaseCube<eT,T2> & in_fcovs_expr,const Base<eT,T3> & in_hefts_expr)136 gmm_full<eT>::set_params(const Base<eT,T1>& in_means_expr, const BaseCube<eT,T2>& in_fcovs_expr, const Base<eT,T3>& in_hefts_expr)
137   {
138   arma_extra_debug_sigprint();
139 
140   const unwrap     <T1> tmp1(in_means_expr.get_ref());
141   const unwrap_cube<T2> tmp2(in_fcovs_expr.get_ref());
142   const unwrap     <T3> tmp3(in_hefts_expr.get_ref());
143 
144   const Mat <eT>& in_means = tmp1.M;
145   const Cube<eT>& in_fcovs = tmp2.M;
146   const Mat <eT>& in_hefts = tmp3.M;
147 
148   arma_debug_check
149     (
150     (in_means.n_cols != in_fcovs.n_slices) || (in_means.n_rows != in_fcovs.n_rows) || (in_fcovs.n_rows != in_fcovs.n_cols) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
151     "gmm_full::set_params(): given parameters have inconsistent and/or wrong sizes"
152     );
153 
154   arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_params(): given means have non-finite values" );
155   arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_params(): given fcovs have non-finite values" );
156   arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_params(): given hefts have non-finite values" );
157 
158   for(uword g=0; g < in_fcovs.n_slices; ++g)
159     {
160     arma_debug_check( (any(diagvec(in_fcovs.slice(g)) <= eT(0))), "gmm_full::set_params(): given fcovs have negative or zero values on diagonals" );
161     }
162 
163   arma_debug_check( (any(vectorise(in_hefts) <  eT(0))), "gmm_full::set_params(): given hefts have negative values" );
164 
165   const eT s = accu(in_hefts);
166 
167   arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_params(): sum of given hefts is not 1" );
168 
169   access::rw(means) = in_means;
170   access::rw(fcovs) = in_fcovs;
171   access::rw(hefts) = in_hefts;
172 
173   init_constants();
174   }
175 
176 
177 
178 template<typename eT>
179 template<typename T1>
180 inline
181 void
set_means(const Base<eT,T1> & in_means_expr)182 gmm_full<eT>::set_means(const Base<eT,T1>& in_means_expr)
183   {
184   arma_extra_debug_sigprint();
185 
186   const unwrap<T1> tmp(in_means_expr.get_ref());
187 
188   const Mat<eT>& in_means = tmp.M;
189 
190   arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_full::set_means(): given means have incompatible size" );
191   arma_debug_check( (in_means.is_finite() == false),             "gmm_full::set_means(): given means have non-finite values" );
192 
193   access::rw(means) = in_means;
194   }
195 
196 
197 
198 template<typename eT>
199 template<typename T1>
200 inline
201 void
set_fcovs(const BaseCube<eT,T1> & in_fcovs_expr)202 gmm_full<eT>::set_fcovs(const BaseCube<eT,T1>& in_fcovs_expr)
203   {
204   arma_extra_debug_sigprint();
205 
206   const unwrap_cube<T1> tmp(in_fcovs_expr.get_ref());
207 
208   const Cube<eT>& in_fcovs = tmp.M;
209 
210   arma_debug_check( (arma::size(in_fcovs) != arma::size(fcovs)), "gmm_full::set_fcovs(): given fcovs have incompatible size" );
211   arma_debug_check( (in_fcovs.is_finite() == false),             "gmm_full::set_fcovs(): given fcovs have non-finite values" );
212 
213   for(uword i=0; i < in_fcovs.n_slices; ++i)
214     {
215     arma_debug_check( (any(diagvec(in_fcovs.slice(i)) <= eT(0))), "gmm_full::set_fcovs(): given fcovs have negative or zero values on diagonals" );
216     }
217 
218   access::rw(fcovs) = in_fcovs;
219 
220   init_constants();
221   }
222 
223 
224 
225 template<typename eT>
226 template<typename T1>
227 inline
228 void
set_hefts(const Base<eT,T1> & in_hefts_expr)229 gmm_full<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
230   {
231   arma_extra_debug_sigprint();
232 
233   const unwrap<T1> tmp(in_hefts_expr.get_ref());
234 
235   const Mat<eT>& in_hefts = tmp.M;
236 
237   arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_full::set_hefts(): given hefts have incompatible size" );
238   arma_debug_check( (in_hefts.is_finite() == false),             "gmm_full::set_hefts(): given hefts have non-finite values" );
239   arma_debug_check( (any(vectorise(in_hefts) <  eT(0))),         "gmm_full::set_hefts(): given hefts have negative values"   );
240 
241   const eT s = accu(in_hefts);
242 
243   arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_hefts(): sum of given hefts is not 1" );
244 
245   // make sure all hefts are positive and non-zero
246 
247   const eT* in_hefts_mem = in_hefts.memptr();
248         eT*    hefts_mem = access::rw(hefts).memptr();
249 
250   for(uword i=0; i < hefts.n_elem; ++i)
251     {
252     hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
253     }
254 
255   access::rw(hefts) /= accu(hefts);
256 
257   log_hefts = log(hefts);
258   }
259 
260 
261 
262 template<typename eT>
263 inline
264 uword
n_dims() const265 gmm_full<eT>::n_dims() const
266   {
267   return means.n_rows;
268   }
269 
270 
271 
272 template<typename eT>
273 inline
274 uword
n_gaus() const275 gmm_full<eT>::n_gaus() const
276   {
277   return means.n_cols;
278   }
279 
280 
281 
282 template<typename eT>
283 inline
284 bool
load(const std::string name)285 gmm_full<eT>::load(const std::string name)
286   {
287   arma_extra_debug_sigprint();
288 
289   field< Mat<eT> > storage;
290 
291   bool status = storage.load(name, arma_binary);
292 
293   if( (status == false) || (storage.n_elem < 2) )
294     {
295     reset();
296     arma_debug_warn_level(3, "gmm_full::load(): problem with loading or incompatible format");
297     return false;
298     }
299 
300   uword count = 0;
301 
302   const Mat<eT>& storage_means = storage(count);  ++count;
303   const Mat<eT>& storage_hefts = storage(count);  ++count;
304 
305   const uword N_dims = storage_means.n_rows;
306   const uword N_gaus = storage_means.n_cols;
307 
308   if( (storage.n_elem != (N_gaus + 2)) || (storage_hefts.n_rows != 1) || (storage_hefts.n_cols != N_gaus) )
309     {
310     reset();
311     arma_debug_warn_level(3, "gmm_full::load(): incompatible format");
312     return false;
313     }
314 
315   reset(N_dims, N_gaus);
316 
317   access::rw(means) = storage_means;
318   access::rw(hefts) = storage_hefts;
319 
320   for(uword g=0; g < N_gaus; ++g)
321     {
322     const Mat<eT>& storage_fcov = storage(count);  ++count;
323 
324     if( (storage_fcov.n_rows != N_dims) || (storage_fcov.n_cols != N_dims) )
325       {
326       reset();
327       arma_debug_warn_level(3, "gmm_full::load(): incompatible format");
328       return false;
329       }
330 
331     access::rw(fcovs).slice(g) = storage_fcov;
332     }
333 
334   init_constants();
335 
336   return true;
337   }
338 
339 
340 
341 template<typename eT>
342 inline
343 bool
save(const std::string name) const344 gmm_full<eT>::save(const std::string name) const
345   {
346   arma_extra_debug_sigprint();
347 
348   const uword N_gaus = means.n_cols;
349 
350   field< Mat<eT> > storage(2 + N_gaus);
351 
352   uword count = 0;
353 
354   storage(count) = means;  ++count;
355   storage(count) = hefts;  ++count;
356 
357   for(uword g=0; g < N_gaus; ++g)
358     {
359     storage(count) = fcovs.slice(g);  ++count;
360     }
361 
362   const bool status = storage.save(name, arma_binary);
363 
364   return status;
365   }
366 
367 
368 
369 template<typename eT>
370 inline
371 Col<eT>
generate() const372 gmm_full<eT>::generate() const
373   {
374   arma_extra_debug_sigprint();
375 
376   const uword N_dims = means.n_rows;
377   const uword N_gaus = means.n_cols;
378 
379   Col<eT> out( (N_gaus > 0) ? N_dims : uword(0), arma_nozeros_indicator() );
380   Col<eT> tmp( (N_gaus > 0) ? N_dims : uword(0), fill::randn              );
381 
382   if(N_gaus > 0)
383     {
384     const double val = randu<double>();
385 
386     double csum    = double(0);
387     uword  gaus_id = 0;
388 
389     for(uword j=0; j < N_gaus; ++j)
390       {
391       csum += hefts[j];
392 
393       if(val <= csum)  { gaus_id = j; break; }
394       }
395 
396     out  = chol_fcovs.slice(gaus_id) * tmp;
397     out += means.col(gaus_id);
398     }
399 
400   return out;
401   }
402 
403 
404 
405 template<typename eT>
406 inline
407 Mat<eT>
generate(const uword N_vec) const408 gmm_full<eT>::generate(const uword N_vec) const
409   {
410   arma_extra_debug_sigprint();
411 
412   const uword N_dims = means.n_rows;
413   const uword N_gaus = means.n_cols;
414 
415   Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, arma_nozeros_indicator() );
416   Mat<eT> tmp( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn              );
417 
418   if(N_gaus > 0)
419     {
420     const eT* hefts_mem = hefts.memptr();
421 
422     for(uword i=0; i < N_vec; ++i)
423       {
424       const double val = randu<double>();
425 
426       double csum    = double(0);
427       uword  gaus_id = 0;
428 
429       for(uword j=0; j < N_gaus; ++j)
430         {
431         csum += hefts_mem[j];
432 
433         if(val <= csum)  { gaus_id = j; break; }
434         }
435 
436       Col<eT> out_vec(out.colptr(i), N_dims, false, true);
437       Col<eT> tmp_vec(tmp.colptr(i), N_dims, false, true);
438 
439       out_vec  = chol_fcovs.slice(gaus_id) * tmp_vec;
440       out_vec += means.col(gaus_id);
441       }
442     }
443 
444   return out;
445   }
446 
447 
448 
449 template<typename eT>
450 template<typename T1>
451 inline
452 eT
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const453 gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
454   {
455   arma_extra_debug_sigprint();
456   arma_ignore(junk1);
457   arma_ignore(junk2);
458 
459   const uword N_dims = means.n_rows;
460 
461   const quasi_unwrap<T1> U(expr);
462 
463   arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
464 
465   return internal_scalar_log_p( U.M.memptr() );
466   }
467 
468 
469 
470 template<typename eT>
471 template<typename T1>
472 inline
473 eT
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const474 gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
475   {
476   arma_extra_debug_sigprint();
477   arma_ignore(junk2);
478 
479   const uword N_dims = means.n_rows;
480 
481   const quasi_unwrap<T1> U(expr);
482 
483   arma_debug_check( (U.M.n_rows != N_dims),    "gmm_full::log_p(): incompatible dimensions"            );
484   arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
485 
486   return internal_scalar_log_p( U.M.memptr(), gaus_id );
487   }
488 
489 
490 
491 template<typename eT>
492 template<typename T1>
493 inline
494 Row<eT>
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const495 gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
496   {
497   arma_extra_debug_sigprint();
498   arma_ignore(junk1);
499   arma_ignore(junk2);
500 
501   const quasi_unwrap<T1> tmp(expr);
502 
503   const Mat<eT>& X = tmp.M;
504 
505   return internal_vec_log_p(X);
506   }
507 
508 
509 
510 template<typename eT>
511 template<typename T1>
512 inline
513 Row<eT>
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const514 gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
515   {
516   arma_extra_debug_sigprint();
517   arma_ignore(junk2);
518 
519   const quasi_unwrap<T1> tmp(expr);
520 
521   const Mat<eT>& X = tmp.M;
522 
523   return internal_vec_log_p(X, gaus_id);
524   }
525 
526 
527 
528 template<typename eT>
529 template<typename T1>
530 inline
531 eT
sum_log_p(const Base<eT,T1> & expr) const532 gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr) const
533   {
534   arma_extra_debug_sigprint();
535 
536   const quasi_unwrap<T1> tmp(expr.get_ref());
537 
538   const Mat<eT>& X = tmp.M;
539 
540   return internal_sum_log_p(X);
541   }
542 
543 
544 
545 template<typename eT>
546 template<typename T1>
547 inline
548 eT
sum_log_p(const Base<eT,T1> & expr,const uword gaus_id) const549 gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
550   {
551   arma_extra_debug_sigprint();
552 
553   const quasi_unwrap<T1> tmp(expr.get_ref());
554 
555   const Mat<eT>& X = tmp.M;
556 
557   return internal_sum_log_p(X, gaus_id);
558   }
559 
560 
561 
562 template<typename eT>
563 template<typename T1>
564 inline
565 eT
avg_log_p(const Base<eT,T1> & expr) const566 gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr) const
567   {
568   arma_extra_debug_sigprint();
569 
570   const quasi_unwrap<T1> tmp(expr.get_ref());
571 
572   const Mat<eT>& X = tmp.M;
573 
574   return internal_avg_log_p(X);
575   }
576 
577 
578 
579 template<typename eT>
580 template<typename T1>
581 inline
582 eT
avg_log_p(const Base<eT,T1> & expr,const uword gaus_id) const583 gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
584   {
585   arma_extra_debug_sigprint();
586 
587   const quasi_unwrap<T1> tmp(expr.get_ref());
588 
589   const Mat<eT>& X = tmp.M;
590 
591   return internal_avg_log_p(X, gaus_id);
592   }
593 
594 
595 
596 template<typename eT>
597 template<typename T1>
598 inline
599 uword
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk) const600 gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
601   {
602   arma_extra_debug_sigprint();
603   arma_ignore(junk);
604 
605   const quasi_unwrap<T1> tmp(expr);
606 
607   const Mat<eT>& X = tmp.M;
608 
609   return internal_scalar_assign(X, dist);
610   }
611 
612 
613 
614 template<typename eT>
615 template<typename T1>
616 inline
617 urowvec
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk) const618 gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
619   {
620   arma_extra_debug_sigprint();
621   arma_ignore(junk);
622 
623   urowvec out;
624 
625   const quasi_unwrap<T1> tmp(expr);
626 
627   const Mat<eT>& X = tmp.M;
628 
629   internal_vec_assign(out, X, dist);
630 
631   return out;
632   }
633 
634 
635 
636 template<typename eT>
637 template<typename T1>
638 inline
639 urowvec
raw_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const640 gmm_full<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
641   {
642   arma_extra_debug_sigprint();
643 
644   const unwrap<T1>   tmp(expr.get_ref());
645   const Mat<eT>& X = tmp.M;
646 
647   arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::raw_hist(): incompatible dimensions" );
648 
649   arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::raw_hist(): unsupported distance mode" );
650 
651   urowvec hist;
652 
653   internal_raw_hist(hist, X, dist_mode);
654 
655   return hist;
656   }
657 
658 
659 
660 template<typename eT>
661 template<typename T1>
662 inline
663 Row<eT>
norm_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const664 gmm_full<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
665   {
666   arma_extra_debug_sigprint();
667 
668   const unwrap<T1>   tmp(expr.get_ref());
669   const Mat<eT>& X = tmp.M;
670 
671   arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::norm_hist(): incompatible dimensions" );
672 
673   arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::norm_hist(): unsupported distance mode" );
674 
675   urowvec hist;
676 
677   internal_raw_hist(hist, X, dist_mode);
678 
679   const uword  hist_n_elem = hist.n_elem;
680   const uword* hist_mem    = hist.memptr();
681 
682   eT acc = eT(0);
683   for(uword i=0; i<hist_n_elem; ++i)  { acc += eT(hist_mem[i]); }
684 
685   if(acc == eT(0))  { acc = eT(1); }
686 
687   Row<eT> out(hist_n_elem, arma_nozeros_indicator());
688 
689   eT* out_mem = out.memptr();
690 
691   for(uword i=0; i<hist_n_elem; ++i)  { out_mem[i] = eT(hist_mem[i]) / acc; }
692 
693   return out;
694   }
695 
696 
697 
698 template<typename eT>
699 template<typename T1>
700 inline
701 bool
learn(const Base<eT,T1> & data,const uword N_gaus,const gmm_dist_mode & dist_mode,const gmm_seed_mode & seed_mode,const uword km_iter,const uword em_iter,const eT var_floor,const bool print_mode)702 gmm_full<eT>::learn
703   (
704   const Base<eT,T1>&   data,
705   const uword          N_gaus,
706   const gmm_dist_mode& dist_mode,
707   const gmm_seed_mode& seed_mode,
708   const uword          km_iter,
709   const uword          em_iter,
710   const eT             var_floor,
711   const bool           print_mode
712   )
713   {
714   arma_extra_debug_sigprint();
715 
716   const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
717 
718   const bool seed_mode_ok = \
719        (seed_mode == keep_existing)
720     || (seed_mode == static_subset)
721     || (seed_mode == static_spread)
722     || (seed_mode == random_subset)
723     || (seed_mode == random_spread);
724 
725   arma_debug_check( (dist_mode_ok == false), "gmm_full::learn(): dist_mode must be eucl_dist or maha_dist" );
726   arma_debug_check( (seed_mode_ok == false), "gmm_full::learn(): unknown seed_mode"                        );
727   arma_debug_check( (var_floor < eT(0)    ), "gmm_full::learn(): variance floor is negative"               );
728 
729   const unwrap<T1>   tmp_X(data.get_ref());
730   const Mat<eT>& X = tmp_X.M;
731 
732   if(X.is_empty()          )  { arma_debug_warn_level(3, "gmm_full::learn(): given matrix is empty"             ); return false; }
733   if(X.is_finite() == false)  { arma_debug_warn_level(3, "gmm_full::learn(): given matrix has non-finite values"); return false; }
734 
735   if(N_gaus == 0)  { reset(); return true; }
736 
737   if(dist_mode == maha_dist)
738     {
739     mah_aux = var(X,1,1);
740 
741     const uword mah_aux_n_elem = mah_aux.n_elem;
742           eT*   mah_aux_mem    = mah_aux.memptr();
743 
744     for(uword i=0; i < mah_aux_n_elem; ++i)
745       {
746       const eT val = mah_aux_mem[i];
747 
748       mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
749       }
750     }
751 
752 
753   // copy current model, in case of failure by k-means and/or EM
754 
755   const gmm_full<eT> orig = (*this);
756 
757 
758   // initial means
759 
760   if(seed_mode == keep_existing)
761     {
762     if(means.is_empty()        )  { arma_debug_warn_level(3, "gmm_full::learn(): no existing means"      ); return false; }
763     if(X.n_rows != means.n_rows)  { arma_debug_warn_level(3, "gmm_full::learn(): dimensionality mismatch"); return false; }
764 
765     // TODO: also check for number of vectors?
766     }
767   else
768     {
769     if(X.n_cols < N_gaus)  { arma_debug_warn_level(3, "gmm_full::learn(): number of vectors is less than number of gaussians"); return false; }
770 
771     reset(X.n_rows, N_gaus);
772 
773     if(print_mode)  { get_cout_stream() << "gmm_full::learn(): generating initial means\n"; get_cout_stream().flush(); }
774 
775          if(dist_mode == eucl_dist)  { generate_initial_means<1>(X, seed_mode); }
776     else if(dist_mode == maha_dist)  { generate_initial_means<2>(X, seed_mode); }
777     }
778 
779 
780   // k-means
781 
782   if(km_iter > 0)
783     {
784     const arma_ostream_state stream_state(get_cout_stream());
785 
786     bool status = false;
787 
788          if(dist_mode == eucl_dist)  { status = km_iterate<1>(X, km_iter, print_mode); }
789     else if(dist_mode == maha_dist)  { status = km_iterate<2>(X, km_iter, print_mode); }
790 
791     stream_state.restore(get_cout_stream());
792 
793     if(status == false)  { arma_debug_warn_level(3, "gmm_full::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
794     }
795 
796 
797   // initial fcovs
798 
799   const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
800 
801   if(seed_mode != keep_existing)
802     {
803     if(print_mode)  { get_cout_stream() << "gmm_full::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
804 
805          if(dist_mode == eucl_dist)  { generate_initial_params<1>(X, var_floor_actual); }
806     else if(dist_mode == maha_dist)  { generate_initial_params<2>(X, var_floor_actual); }
807     }
808 
809 
810   // EM algorithm
811 
812   if(em_iter > 0)
813     {
814     const arma_ostream_state stream_state(get_cout_stream());
815 
816     const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
817 
818     stream_state.restore(get_cout_stream());
819 
820     if(status == false)  { arma_debug_warn_level(3, "gmm_full::learn(): EM algorithm failed"); init(orig); return false; }
821     }
822 
823   mah_aux.reset();
824 
825   init_constants();
826 
827   return true;
828   }
829 
830 
831 
832 //
833 //
834 //
835 
836 
837 
838 template<typename eT>
839 inline
840 void
init(const gmm_full<eT> & x)841 gmm_full<eT>::init(const gmm_full<eT>& x)
842   {
843   arma_extra_debug_sigprint();
844 
845   gmm_full<eT>& t = *this;
846 
847   if(&t != &x)
848     {
849     access::rw(t.means) = x.means;
850     access::rw(t.fcovs) = x.fcovs;
851     access::rw(t.hefts) = x.hefts;
852 
853     init_constants();
854     }
855   }
856 
857 
858 
859 template<typename eT>
860 inline
861 void
init(const gmm_diag<eT> & x)862 gmm_full<eT>::init(const gmm_diag<eT>& x)
863   {
864   arma_extra_debug_sigprint();
865 
866   access::rw(hefts) = x.hefts;
867   access::rw(means) = x.means;
868 
869   const uword N_dims = x.means.n_rows;
870   const uword N_gaus = x.means.n_cols;
871 
872   access::rw(fcovs).zeros(N_dims,N_dims,N_gaus);
873 
874   for(uword g=0; g < N_gaus; ++g)
875     {
876     Mat<eT>& fcov = access::rw(fcovs).slice(g);
877 
878     const eT* dcov_mem = x.dcovs.colptr(g);
879 
880     for(uword d=0; d < N_dims; ++d)
881       {
882       fcov.at(d,d) = dcov_mem[d];
883       }
884     }
885 
886   init_constants();
887   }
888 
889 
890 
891 template<typename eT>
892 inline
893 void
init(const uword in_n_dims,const uword in_n_gaus)894 gmm_full<eT>::init(const uword in_n_dims, const uword in_n_gaus)
895   {
896   arma_extra_debug_sigprint();
897 
898   access::rw(means).zeros(in_n_dims, in_n_gaus);
899 
900   access::rw(fcovs).zeros(in_n_dims, in_n_dims, in_n_gaus);
901 
902   for(uword g=0; g < in_n_gaus; ++g)
903     {
904     access::rw(fcovs).slice(g).diag().ones();
905     }
906 
907   access::rw(hefts).set_size(in_n_gaus);
908   access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
909 
910   init_constants();
911   }
912 
913 
914 
915 template<typename eT>
916 inline
917 void
init_constants(const bool calc_chol)918 gmm_full<eT>::init_constants(const bool calc_chol)
919   {
920   arma_extra_debug_sigprint();
921 
922   const uword N_dims = means.n_rows;
923   const uword N_gaus = means.n_cols;
924 
925   const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
926 
927   //
928 
929   inv_fcovs.copy_size(fcovs);
930   log_det_etc.set_size(N_gaus);
931 
932   Mat<eT> tmp_inv;
933 
934   for(uword g=0; g < N_gaus; ++g)
935     {
936     const Mat<eT>&     fcov =      fcovs.slice(g);
937           Mat<eT>& inv_fcov =  inv_fcovs.slice(g);
938 
939   //const bool inv_ok = auxlib::inv(tmp_inv, fcov);
940     const bool inv_ok = auxlib::inv_sympd(tmp_inv, fcov);
941 
942     eT log_det_val  = eT(0);
943     eT log_det_sign = eT(0);
944 
945     const bool log_det_status = log_det(log_det_val, log_det_sign, fcov);
946 
947     const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
948 
949     if(inv_ok && log_det_ok)
950       {
951       inv_fcov = tmp_inv;
952       }
953     else
954       {
955       // last resort: treat the covariance matrix as diagonal
956 
957       inv_fcov.zeros();
958 
959       log_det_val = eT(0);
960 
961       for(uword d=0; d < N_dims; ++d)
962         {
963         const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
964 
965         inv_fcov.at(d,d) = eT(1) / sanitised_val;
966 
967         log_det_val += std::log(sanitised_val);
968         }
969       }
970 
971     log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
972     }
973 
974   //
975 
976   eT* hefts_mem = access::rw(hefts).memptr();
977 
978   for(uword g=0; g < N_gaus; ++g)
979     {
980     hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
981     }
982 
983   log_hefts = log(hefts);
984 
985 
986   if(calc_chol)
987     {
988     chol_fcovs.copy_size(fcovs);
989 
990     Mat<eT> tmp_chol;
991 
992     for(uword g=0; g < N_gaus; ++g)
993       {
994       const Mat<eT>& fcov      =      fcovs.slice(g);
995             Mat<eT>& chol_fcov = chol_fcovs.slice(g);
996 
997       const uword chol_layout = 1;  // indicates "lower"
998 
999       const bool chol_ok = op_chol::apply_direct(tmp_chol, fcov, chol_layout);
1000 
1001       if(chol_ok)
1002         {
1003         chol_fcov = tmp_chol;
1004         }
1005       else
1006         {
1007         // last resort: treat the covariance matrix as diagonal
1008 
1009         chol_fcov.zeros();
1010 
1011         for(uword d=0; d < N_dims; ++d)
1012           {
1013           const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
1014 
1015           chol_fcov.at(d,d) = std::sqrt(sanitised_val);
1016           }
1017         }
1018       }
1019     }
1020   }
1021 
1022 
1023 
1024 template<typename eT>
1025 inline
1026 umat
internal_gen_boundaries(const uword N) const1027 gmm_full<eT>::internal_gen_boundaries(const uword N) const
1028   {
1029   arma_extra_debug_sigprint();
1030 
1031   #if defined(ARMA_USE_OPENMP)
1032     const uword n_threads_avail = uword(omp_get_max_threads());
1033     const uword n_threads       = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
1034   #else
1035     static constexpr uword n_threads = 1;
1036   #endif
1037 
1038   // get_cout_stream() << "gmm_full::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
1039 
1040   umat boundaries(2, n_threads, arma_nozeros_indicator());
1041 
1042   if(N > 0)
1043     {
1044     const uword chunk_size = N / n_threads;
1045 
1046     uword count = 0;
1047 
1048     for(uword t=0; t<n_threads; t++)
1049       {
1050       boundaries.at(0,t) = count;
1051 
1052       count += chunk_size;
1053 
1054       boundaries.at(1,t) = count-1;
1055       }
1056 
1057     boundaries.at(1,n_threads-1) = N - 1;
1058     }
1059   else
1060     {
1061     boundaries.zeros();
1062     }
1063 
1064   // get_cout_stream() << "gmm_full::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
1065 
1066   return boundaries;
1067   }
1068 
1069 
1070 
1071 template<typename eT>
1072 inline
1073 eT
internal_scalar_log_p(const eT * x) const1074 gmm_full<eT>::internal_scalar_log_p(const eT* x) const
1075   {
1076   arma_extra_debug_sigprint();
1077 
1078   const eT* log_hefts_mem = log_hefts.mem;
1079 
1080   const uword N_gaus = means.n_cols;
1081 
1082   if(N_gaus > 0)
1083     {
1084     eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
1085 
1086     for(uword g=1; g < N_gaus; ++g)
1087       {
1088       const eT log_val = internal_scalar_log_p(x, g) + log_hefts_mem[g];
1089 
1090       log_sum = log_add_exp(log_sum, log_val);
1091       }
1092 
1093     return log_sum;
1094     }
1095   else
1096     {
1097     return -Datum<eT>::inf;
1098     }
1099   }
1100 
1101 
1102 
1103 template<typename eT>
1104 inline
1105 eT
internal_scalar_log_p(const eT * x,const uword g) const1106 gmm_full<eT>::internal_scalar_log_p(const eT* x, const uword g) const
1107   {
1108   arma_extra_debug_sigprint();
1109 
1110   const uword N_dims   = means.n_rows;
1111   const eT*   mean_mem = means.colptr(g);
1112 
1113   eT outer_acc = eT(0);
1114 
1115   const eT* inv_fcov_coldata = inv_fcovs.slice(g).memptr();
1116 
1117   for(uword i=0; i < N_dims; ++i)
1118     {
1119     eT inner_acc = eT(0);
1120 
1121     for(uword j=0; j < N_dims; ++j)
1122       {
1123       inner_acc += (x[j] - mean_mem[j]) * inv_fcov_coldata[j];
1124       }
1125 
1126     inv_fcov_coldata += N_dims;
1127 
1128     outer_acc += inner_acc * (x[i] - mean_mem[i]);
1129     }
1130 
1131   return eT(-0.5)*outer_acc + log_det_etc.mem[g];
1132   }
1133 
1134 
1135 
1136 template<typename eT>
1137 inline
1138 Row<eT>
internal_vec_log_p(const Mat<eT> & X) const1139 gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X) const
1140   {
1141   arma_extra_debug_sigprint();
1142 
1143   const uword N_dims    = means.n_rows;
1144   const uword N_samples = X.n_cols;
1145 
1146   arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
1147 
1148   Row<eT> out(N_samples, arma_nozeros_indicator());
1149 
1150   if(N_samples > 0)
1151     {
1152     #if defined(ARMA_USE_OPENMP)
1153       {
1154       const umat boundaries = internal_gen_boundaries(N_samples);
1155 
1156       const uword n_threads = boundaries.n_cols;
1157 
1158       #pragma omp parallel for schedule(static)
1159       for(uword t=0; t < n_threads; ++t)
1160         {
1161         const uword start_index = boundaries.at(0,t);
1162         const uword   end_index = boundaries.at(1,t);
1163 
1164         eT* out_mem = out.memptr();
1165 
1166         for(uword i=start_index; i <= end_index; ++i)
1167           {
1168           out_mem[i] = internal_scalar_log_p( X.colptr(i) );
1169           }
1170         }
1171       }
1172     #else
1173       {
1174       eT* out_mem = out.memptr();
1175 
1176       for(uword i=0; i < N_samples; ++i)
1177         {
1178         out_mem[i] = internal_scalar_log_p( X.colptr(i) );
1179         }
1180       }
1181     #endif
1182     }
1183 
1184   return out;
1185   }
1186 
1187 
1188 
1189 template<typename eT>
1190 inline
1191 Row<eT>
internal_vec_log_p(const Mat<eT> & X,const uword gaus_id) const1192 gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
1193   {
1194   arma_extra_debug_sigprint();
1195 
1196   const uword N_dims    = means.n_rows;
1197   const uword N_samples = X.n_cols;
1198 
1199   arma_debug_check( (X.n_rows != N_dims),       "gmm_full::log_p(): incompatible dimensions"            );
1200   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
1201 
1202   Row<eT> out(N_samples, arma_nozeros_indicator());
1203 
1204   if(N_samples > 0)
1205     {
1206     #if defined(ARMA_USE_OPENMP)
1207       {
1208       const umat boundaries = internal_gen_boundaries(N_samples);
1209 
1210       const uword n_threads = boundaries.n_cols;
1211 
1212       #pragma omp parallel for schedule(static)
1213       for(uword t=0; t < n_threads; ++t)
1214         {
1215         const uword start_index = boundaries.at(0,t);
1216         const uword   end_index = boundaries.at(1,t);
1217 
1218         eT* out_mem = out.memptr();
1219 
1220         for(uword i=start_index; i <= end_index; ++i)
1221           {
1222           out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
1223           }
1224         }
1225       }
1226     #else
1227       {
1228       eT* out_mem = out.memptr();
1229 
1230       for(uword i=0; i < N_samples; ++i)
1231         {
1232         out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
1233         }
1234       }
1235     #endif
1236     }
1237 
1238   return out;
1239   }
1240 
1241 
1242 
1243 template<typename eT>
1244 inline
1245 eT
internal_sum_log_p(const Mat<eT> & X) const1246 gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X) const
1247   {
1248   arma_extra_debug_sigprint();
1249 
1250   arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" );
1251 
1252   const uword N = X.n_cols;
1253 
1254   if(N == 0)  { return (-Datum<eT>::inf); }
1255 
1256 
1257   #if defined(ARMA_USE_OPENMP)
1258     {
1259     const umat boundaries = internal_gen_boundaries(N);
1260 
1261     const uword n_threads = boundaries.n_cols;
1262 
1263     Col<eT> t_accs(n_threads, arma_zeros_indicator());
1264 
1265     #pragma omp parallel for schedule(static)
1266     for(uword t=0; t < n_threads; ++t)
1267       {
1268       const uword start_index = boundaries.at(0,t);
1269       const uword   end_index = boundaries.at(1,t);
1270 
1271       eT t_acc = eT(0);
1272 
1273       for(uword i=start_index; i <= end_index; ++i)
1274         {
1275         t_acc += internal_scalar_log_p( X.colptr(i) );
1276         }
1277 
1278       t_accs[t] = t_acc;
1279       }
1280 
1281     return eT(accu(t_accs));
1282     }
1283   #else
1284     {
1285     eT acc = eT(0);
1286 
1287     for(uword i=0; i<N; ++i)
1288       {
1289       acc += internal_scalar_log_p( X.colptr(i) );
1290       }
1291 
1292     return acc;
1293     }
1294   #endif
1295   }
1296 
1297 
1298 
1299 template<typename eT>
1300 inline
1301 eT
internal_sum_log_p(const Mat<eT> & X,const uword gaus_id) const1302 gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
1303   {
1304   arma_extra_debug_sigprint();
1305 
1306   arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions"            );
1307   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_full::sum_log_p(): specified gaussian is out of range" );
1308 
1309   const uword N = X.n_cols;
1310 
1311   if(N == 0)  { return (-Datum<eT>::inf); }
1312 
1313 
1314   #if defined(ARMA_USE_OPENMP)
1315     {
1316     const umat boundaries = internal_gen_boundaries(N);
1317 
1318     const uword n_threads = boundaries.n_cols;
1319 
1320     Col<eT> t_accs(n_threads, arma_zeros_indicator());
1321 
1322     #pragma omp parallel for schedule(static)
1323     for(uword t=0; t < n_threads; ++t)
1324       {
1325       const uword start_index = boundaries.at(0,t);
1326       const uword   end_index = boundaries.at(1,t);
1327 
1328       eT t_acc = eT(0);
1329 
1330       for(uword i=start_index; i <= end_index; ++i)
1331         {
1332         t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
1333         }
1334 
1335       t_accs[t] = t_acc;
1336       }
1337 
1338     return eT(accu(t_accs));
1339     }
1340   #else
1341     {
1342     eT acc = eT(0);
1343 
1344     for(uword i=0; i<N; ++i)
1345       {
1346       acc += internal_scalar_log_p( X.colptr(i), gaus_id );
1347       }
1348 
1349     return acc;
1350     }
1351   #endif
1352   }
1353 
1354 
1355 
1356 template<typename eT>
1357 inline
1358 eT
internal_avg_log_p(const Mat<eT> & X) const1359 gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X) const
1360   {
1361   arma_extra_debug_sigprint();
1362 
1363   const uword N_dims    = means.n_rows;
1364   const uword N_samples = X.n_cols;
1365 
1366   arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" );
1367 
1368   if(N_samples == 0)  { return (-Datum<eT>::inf); }
1369 
1370 
1371   #if defined(ARMA_USE_OPENMP)
1372     {
1373     const umat boundaries = internal_gen_boundaries(N_samples);
1374 
1375     const uword n_threads = boundaries.n_cols;
1376 
1377     field< running_mean_scalar<eT> > t_running_means(n_threads);
1378 
1379 
1380     #pragma omp parallel for schedule(static)
1381     for(uword t=0; t < n_threads; ++t)
1382       {
1383       const uword start_index = boundaries.at(0,t);
1384       const uword   end_index = boundaries.at(1,t);
1385 
1386       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1387 
1388       for(uword i=start_index; i <= end_index; ++i)
1389         {
1390         current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
1391         }
1392       }
1393 
1394 
1395     eT avg = eT(0);
1396 
1397     for(uword t=0; t < n_threads; ++t)
1398       {
1399       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1400 
1401       const eT w = eT(current_running_mean.count()) / eT(N_samples);
1402 
1403       avg += w * current_running_mean.mean();
1404       }
1405 
1406     return avg;
1407     }
1408   #else
1409     {
1410     running_mean_scalar<eT> running_mean;
1411 
1412     for(uword i=0; i < N_samples; ++i)
1413       {
1414       running_mean( internal_scalar_log_p( X.colptr(i) ) );
1415       }
1416 
1417     return running_mean.mean();
1418     }
1419   #endif
1420   }
1421 
1422 
1423 
1424 template<typename eT>
1425 inline
1426 eT
internal_avg_log_p(const Mat<eT> & X,const uword gaus_id) const1427 gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
1428   {
1429   arma_extra_debug_sigprint();
1430 
1431   const uword N_dims    = means.n_rows;
1432   const uword N_samples = X.n_cols;
1433 
1434   arma_debug_check( (X.n_rows != N_dims),       "gmm_full::avg_log_p(): incompatible dimensions"            );
1435   arma_debug_check( (gaus_id  >= means.n_cols), "gmm_full::avg_log_p(): specified gaussian is out of range" );
1436 
1437   if(N_samples == 0)  { return (-Datum<eT>::inf); }
1438 
1439 
1440   #if defined(ARMA_USE_OPENMP)
1441     {
1442     const umat boundaries = internal_gen_boundaries(N_samples);
1443 
1444     const uword n_threads = boundaries.n_cols;
1445 
1446     field< running_mean_scalar<eT> > t_running_means(n_threads);
1447 
1448 
1449     #pragma omp parallel for schedule(static)
1450     for(uword t=0; t < n_threads; ++t)
1451       {
1452       const uword start_index = boundaries.at(0,t);
1453       const uword   end_index = boundaries.at(1,t);
1454 
1455       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1456 
1457       for(uword i=start_index; i <= end_index; ++i)
1458         {
1459         current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
1460         }
1461       }
1462 
1463 
1464     eT avg = eT(0);
1465 
1466     for(uword t=0; t < n_threads; ++t)
1467       {
1468       running_mean_scalar<eT>& current_running_mean = t_running_means[t];
1469 
1470       const eT w = eT(current_running_mean.count()) / eT(N_samples);
1471 
1472       avg += w * current_running_mean.mean();
1473       }
1474 
1475     return avg;
1476     }
1477   #else
1478     {
1479     running_mean_scalar<eT> running_mean;
1480 
1481     for(uword i=0; i<N_samples; ++i)
1482       {
1483       running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
1484       }
1485 
1486     return running_mean.mean();
1487     }
1488   #endif
1489   }
1490 
1491 
1492 
1493 template<typename eT>
1494 inline
1495 uword
internal_scalar_assign(const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1496 gmm_full<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1497   {
1498   arma_extra_debug_sigprint();
1499 
1500   const uword N_dims = means.n_rows;
1501   const uword N_gaus = means.n_cols;
1502 
1503   arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
1504   arma_debug_check( (N_gaus == 0),        "gmm_full::assign(): model has no means"      );
1505 
1506   const eT* X_mem = X.colptr(0);
1507 
1508   if(dist_mode == eucl_dist)
1509     {
1510     eT    best_dist = Datum<eT>::inf;
1511     uword best_g    = 0;
1512 
1513     for(uword g=0; g < N_gaus; ++g)
1514       {
1515       const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
1516 
1517       if(tmp_dist <= best_dist)
1518         {
1519         best_dist = tmp_dist;
1520         best_g    = g;
1521         }
1522       }
1523 
1524     return best_g;
1525     }
1526   else
1527   if(dist_mode == prob_dist)
1528     {
1529     const eT* log_hefts_mem = log_hefts.memptr();
1530 
1531     eT    best_p = -Datum<eT>::inf;
1532     uword best_g = 0;
1533 
1534     for(uword g=0; g < N_gaus; ++g)
1535       {
1536       const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
1537 
1538       if(tmp_p >= best_p)
1539         {
1540         best_p = tmp_p;
1541         best_g = g;
1542         }
1543       }
1544 
1545     return best_g;
1546     }
1547   else
1548     {
1549     arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
1550     }
1551 
1552   return uword(0);
1553   }
1554 
1555 
1556 
1557 template<typename eT>
1558 inline
1559 void
internal_vec_assign(urowvec & out,const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1560 gmm_full<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1561   {
1562   arma_extra_debug_sigprint();
1563 
1564   const uword N_dims = means.n_rows;
1565   const uword N_gaus = means.n_cols;
1566 
1567   arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
1568 
1569   const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
1570 
1571   out.set_size(1,X_n_cols);
1572 
1573   uword* out_mem = out.memptr();
1574 
1575   if(dist_mode == eucl_dist)
1576     {
1577     #if defined(ARMA_USE_OPENMP)
1578       {
1579       #pragma omp parallel for schedule(static)
1580       for(uword i=0; i<X_n_cols; ++i)
1581         {
1582         const eT* X_colptr = X.colptr(i);
1583 
1584         eT    best_dist = Datum<eT>::inf;
1585         uword best_g    = 0;
1586 
1587         for(uword g=0; g<N_gaus; ++g)
1588           {
1589           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1590 
1591           if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
1592           }
1593 
1594         out_mem[i] = best_g;
1595         }
1596       }
1597     #else
1598       {
1599       for(uword i=0; i<X_n_cols; ++i)
1600         {
1601         const eT* X_colptr = X.colptr(i);
1602 
1603         eT    best_dist = Datum<eT>::inf;
1604         uword best_g    = 0;
1605 
1606         for(uword g=0; g<N_gaus; ++g)
1607           {
1608           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1609 
1610           if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
1611           }
1612 
1613         out_mem[i] = best_g;
1614         }
1615       }
1616     #endif
1617     }
1618   else
1619   if(dist_mode == prob_dist)
1620     {
1621     #if defined(ARMA_USE_OPENMP)
1622       {
1623       const umat boundaries = internal_gen_boundaries(X_n_cols);
1624 
1625       const uword n_threads = boundaries.n_cols;
1626 
1627       const eT* log_hefts_mem = log_hefts.memptr();
1628 
1629       #pragma omp parallel for schedule(static)
1630       for(uword t=0; t < n_threads; ++t)
1631         {
1632         const uword start_index = boundaries.at(0,t);
1633         const uword   end_index = boundaries.at(1,t);
1634 
1635         for(uword i=start_index; i <= end_index; ++i)
1636           {
1637           const eT* X_colptr = X.colptr(i);
1638 
1639           eT    best_p = -Datum<eT>::inf;
1640           uword best_g = 0;
1641 
1642           for(uword g=0; g<N_gaus; ++g)
1643             {
1644             const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1645 
1646             if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
1647             }
1648 
1649           out_mem[i] = best_g;
1650           }
1651         }
1652       }
1653     #else
1654       {
1655       const eT* log_hefts_mem = log_hefts.memptr();
1656 
1657       for(uword i=0; i<X_n_cols; ++i)
1658         {
1659         const eT* X_colptr = X.colptr(i);
1660 
1661         eT    best_p = -Datum<eT>::inf;
1662         uword best_g = 0;
1663 
1664         for(uword g=0; g<N_gaus; ++g)
1665           {
1666           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1667 
1668           if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
1669           }
1670 
1671         out_mem[i] = best_g;
1672         }
1673       }
1674     #endif
1675     }
1676   else
1677     {
1678     arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
1679     }
1680   }
1681 
1682 
1683 
1684 
1685 template<typename eT>
1686 inline
1687 void
internal_raw_hist(urowvec & hist,const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1688 gmm_full<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1689   {
1690   arma_extra_debug_sigprint();
1691 
1692   const uword N_dims = means.n_rows;
1693   const uword N_gaus = means.n_cols;
1694 
1695   const uword X_n_cols = X.n_cols;
1696 
1697   hist.zeros(N_gaus);
1698 
1699   if(N_gaus == 0)  { return; }
1700 
1701   #if defined(ARMA_USE_OPENMP)
1702     {
1703     const umat boundaries = internal_gen_boundaries(X_n_cols);
1704 
1705     const uword n_threads = boundaries.n_cols;
1706 
1707     field<urowvec> thread_hist(n_threads);
1708 
1709     for(uword t=0; t < n_threads; ++t)  { thread_hist(t).zeros(N_gaus); }
1710 
1711 
1712     if(dist_mode == eucl_dist)
1713       {
1714       #pragma omp parallel for schedule(static)
1715       for(uword t=0; t < n_threads; ++t)
1716         {
1717         uword* thread_hist_mem = thread_hist(t).memptr();
1718 
1719         const uword start_index = boundaries.at(0,t);
1720         const uword   end_index = boundaries.at(1,t);
1721 
1722         for(uword i=start_index; i <= end_index; ++i)
1723           {
1724           const eT* X_colptr = X.colptr(i);
1725 
1726           eT    best_dist = Datum<eT>::inf;
1727           uword best_g    = 0;
1728 
1729           for(uword g=0; g < N_gaus; ++g)
1730             {
1731             const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1732 
1733             if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
1734             }
1735 
1736           thread_hist_mem[best_g]++;
1737           }
1738         }
1739       }
1740     else
1741     if(dist_mode == prob_dist)
1742       {
1743       const eT* log_hefts_mem = log_hefts.memptr();
1744 
1745       #pragma omp parallel for schedule(static)
1746       for(uword t=0; t < n_threads; ++t)
1747         {
1748         uword* thread_hist_mem = thread_hist(t).memptr();
1749 
1750         const uword start_index = boundaries.at(0,t);
1751         const uword   end_index = boundaries.at(1,t);
1752 
1753         for(uword i=start_index; i <= end_index; ++i)
1754           {
1755           const eT* X_colptr = X.colptr(i);
1756 
1757           eT    best_p = -Datum<eT>::inf;
1758           uword best_g = 0;
1759 
1760           for(uword g=0; g < N_gaus; ++g)
1761             {
1762             const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1763 
1764             if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
1765             }
1766 
1767           thread_hist_mem[best_g]++;
1768           }
1769         }
1770       }
1771 
1772     // reduction
1773     for(uword t=0; t < n_threads; ++t)
1774       {
1775       hist += thread_hist(t);
1776       }
1777     }
1778   #else
1779     {
1780     uword* hist_mem = hist.memptr();
1781 
1782     if(dist_mode == eucl_dist)
1783       {
1784       for(uword i=0; i<X_n_cols; ++i)
1785         {
1786         const eT* X_colptr = X.colptr(i);
1787 
1788         eT    best_dist = Datum<eT>::inf;
1789         uword best_g    = 0;
1790 
1791         for(uword g=0; g < N_gaus; ++g)
1792           {
1793           const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
1794 
1795           if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
1796           }
1797 
1798         hist_mem[best_g]++;
1799         }
1800       }
1801     else
1802     if(dist_mode == prob_dist)
1803       {
1804       const eT* log_hefts_mem = log_hefts.memptr();
1805 
1806       for(uword i=0; i<X_n_cols; ++i)
1807         {
1808         const eT* X_colptr = X.colptr(i);
1809 
1810         eT    best_p = -Datum<eT>::inf;
1811         uword best_g = 0;
1812 
1813         for(uword g=0; g < N_gaus; ++g)
1814           {
1815           const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
1816 
1817           if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
1818           }
1819 
1820         hist_mem[best_g]++;
1821         }
1822       }
1823     }
1824   #endif
1825   }
1826 
1827 
1828 
1829 template<typename eT>
1830 template<uword dist_id>
1831 inline
1832 void
generate_initial_means(const Mat<eT> & X,const gmm_seed_mode & seed_mode)1833 gmm_full<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
1834   {
1835   arma_extra_debug_sigprint();
1836 
1837   const uword N_dims = means.n_rows;
1838   const uword N_gaus = means.n_cols;
1839 
1840   if( (seed_mode == static_subset) || (seed_mode == random_subset) )
1841     {
1842     uvec initial_indices;
1843 
1844          if(seed_mode == static_subset)  { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
1845     else if(seed_mode == random_subset)  { initial_indices = randperm<uvec>(X.n_cols, N_gaus);      }
1846 
1847     // initial_indices.print("initial_indices:");
1848 
1849     access::rw(means) = X.cols(initial_indices);
1850     }
1851   else
1852   if( (seed_mode == static_spread) || (seed_mode == random_spread) )
1853     {
1854     // going through all of the samples can be extremely time consuming;
1855     // instead, if there are enough samples, randomly choose samples with probability 0.1
1856 
1857     const bool  use_sampling = ((X.n_cols/uword(100)) > N_gaus);
1858     const uword step         = (use_sampling) ? uword(10) : uword(1);
1859 
1860     uword start_index = 0;
1861 
1862          if(seed_mode == static_spread)  { start_index = X.n_cols / 2;                                         }
1863     else if(seed_mode == random_spread)  { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
1864 
1865     access::rw(means).col(0) = X.unsafe_col(start_index);
1866 
1867     const eT* mah_aux_mem = mah_aux.memptr();
1868 
1869     running_stat<double> rs;
1870 
1871     for(uword g=1; g < N_gaus; ++g)
1872       {
1873       eT    max_dist = eT(0);
1874       uword best_i   = uword(0);
1875       uword start_i  = uword(0);
1876 
1877       if(use_sampling)
1878         {
1879         uword start_i_proposed = uword(0);
1880 
1881         if(seed_mode == static_spread)  { start_i_proposed = g % uword(10);                               }
1882         if(seed_mode == random_spread)  { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
1883 
1884         if(start_i_proposed < X.n_cols)  { start_i = start_i_proposed; }
1885         }
1886 
1887 
1888       for(uword i=start_i; i < X.n_cols; i += step)
1889         {
1890         rs.reset();
1891 
1892         const eT* X_colptr = X.colptr(i);
1893 
1894         bool ignore_i = false;
1895 
1896         // find the average distance between sample i and the means so far
1897         for(uword h = 0; h < g; ++h)
1898           {
1899           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
1900 
1901           // ignore sample already selected as a mean
1902           if(dist == eT(0))  { ignore_i = true; break; }
1903           else               { rs(dist);               }
1904           }
1905 
1906         if( (rs.mean() >= max_dist) && (ignore_i == false))
1907           {
1908           max_dist = eT(rs.mean()); best_i = i;
1909           }
1910         }
1911 
1912       // set the mean to the sample that is the furthest away from the means so far
1913       access::rw(means).col(g) = X.unsafe_col(best_i);
1914       }
1915     }
1916 
1917   // get_cout_stream() << "generate_initial_means():" << '\n';
1918   // means.print();
1919   }
1920 
1921 
1922 
1923 template<typename eT>
1924 template<uword dist_id>
1925 inline
1926 void
generate_initial_params(const Mat<eT> & X,const eT var_floor)1927 gmm_full<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
1928   {
1929   arma_extra_debug_sigprint();
1930 
1931   const uword N_dims = means.n_rows;
1932   const uword N_gaus = means.n_cols;
1933 
1934   const eT* mah_aux_mem = mah_aux.memptr();
1935 
1936   const uword X_n_cols = X.n_cols;
1937 
1938   if(X_n_cols == 0)  { return; }
1939 
1940   // as the covariances are calculated via accumulators,
1941   // the means also need to be calculated via accumulators to ensure numerical consistency
1942 
1943   Mat<eT> acc_means(N_dims, N_gaus);
1944   Mat<eT> acc_dcovs(N_dims, N_gaus);
1945 
1946   Row<uword> acc_hefts(N_gaus, arma_zeros_indicator());
1947 
1948   uword* acc_hefts_mem = acc_hefts.memptr();
1949 
1950   #if defined(ARMA_USE_OPENMP)
1951     {
1952     const umat boundaries = internal_gen_boundaries(X_n_cols);
1953 
1954     const uword n_threads = boundaries.n_cols;
1955 
1956     field< Mat<eT>    > t_acc_means(n_threads);
1957     field< Mat<eT>    > t_acc_dcovs(n_threads);
1958     field< Row<uword> > t_acc_hefts(n_threads);
1959 
1960     for(uword t=0; t < n_threads; ++t)
1961       {
1962       t_acc_means(t).zeros(N_dims, N_gaus);
1963       t_acc_dcovs(t).zeros(N_dims, N_gaus);
1964       t_acc_hefts(t).zeros(N_gaus);
1965       }
1966 
1967     #pragma omp parallel for schedule(static)
1968     for(uword t=0; t < n_threads; ++t)
1969       {
1970       uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
1971 
1972       const uword start_index = boundaries.at(0,t);
1973       const uword   end_index = boundaries.at(1,t);
1974 
1975       for(uword i=start_index; i <= end_index; ++i)
1976         {
1977         const eT* X_colptr = X.colptr(i);
1978 
1979         eT     min_dist = Datum<eT>::inf;
1980         uword  best_g   = 0;
1981 
1982         for(uword g=0; g<N_gaus; ++g)
1983           {
1984           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
1985 
1986           if(dist < min_dist)  { min_dist = dist;  best_g = g; }
1987           }
1988 
1989         eT* t_acc_mean = t_acc_means(t).colptr(best_g);
1990         eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
1991 
1992         for(uword d=0; d<N_dims; ++d)
1993           {
1994           const eT x_d = X_colptr[d];
1995 
1996           t_acc_mean[d] += x_d;
1997           t_acc_dcov[d] += x_d*x_d;
1998           }
1999 
2000         t_acc_hefts_mem[best_g]++;
2001         }
2002       }
2003 
2004     // reduction
2005     acc_means = t_acc_means(0);
2006     acc_dcovs = t_acc_dcovs(0);
2007     acc_hefts = t_acc_hefts(0);
2008 
2009     for(uword t=1; t < n_threads; ++t)
2010       {
2011       acc_means += t_acc_means(t);
2012       acc_dcovs += t_acc_dcovs(t);
2013       acc_hefts += t_acc_hefts(t);
2014       }
2015     }
2016   #else
2017     {
2018     for(uword i=0; i<X_n_cols; ++i)
2019       {
2020       const eT* X_colptr = X.colptr(i);
2021 
2022       eT     min_dist = Datum<eT>::inf;
2023       uword  best_g   = 0;
2024 
2025       for(uword g=0; g<N_gaus; ++g)
2026         {
2027         const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
2028 
2029         if(dist < min_dist)  { min_dist = dist;  best_g = g; }
2030         }
2031 
2032       eT* acc_mean = acc_means.colptr(best_g);
2033       eT* acc_dcov = acc_dcovs.colptr(best_g);
2034 
2035       for(uword d=0; d<N_dims; ++d)
2036         {
2037         const eT x_d = X_colptr[d];
2038 
2039         acc_mean[d] += x_d;
2040         acc_dcov[d] += x_d*x_d;
2041         }
2042 
2043       acc_hefts_mem[best_g]++;
2044       }
2045     }
2046   #endif
2047 
2048   eT* hefts_mem = access::rw(hefts).memptr();
2049 
2050   for(uword g=0; g<N_gaus; ++g)
2051     {
2052     const eT*   acc_mean = acc_means.colptr(g);
2053     const eT*   acc_dcov = acc_dcovs.colptr(g);
2054     const uword acc_heft = acc_hefts_mem[g];
2055 
2056     eT* mean = access::rw(means).colptr(g);
2057 
2058     Mat<eT>& fcov = access::rw(fcovs).slice(g);
2059     fcov.zeros();
2060 
2061     for(uword d=0; d<N_dims; ++d)
2062       {
2063       const eT tmp = acc_mean[d] / eT(acc_heft);
2064 
2065       mean[d]      = (acc_heft >= 1) ? tmp : eT(0);
2066       fcov.at(d,d) = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
2067       }
2068 
2069     hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
2070     }
2071 
2072   em_fix_params(var_floor);
2073   }
2074 
2075 
2076 
2077 //! multi-threaded implementation of k-means, inspired by MapReduce
2078 template<typename eT>
2079 template<uword dist_id>
2080 inline
2081 bool
km_iterate(const Mat<eT> & X,const uword max_iter,const bool verbose)2082 gmm_full<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose)
2083   {
2084   arma_extra_debug_sigprint();
2085 
2086   if(verbose)
2087     {
2088     get_cout_stream().unsetf(ios::showbase);
2089     get_cout_stream().unsetf(ios::uppercase);
2090     get_cout_stream().unsetf(ios::showpos);
2091     get_cout_stream().unsetf(ios::scientific);
2092 
2093     get_cout_stream().setf(ios::right);
2094     get_cout_stream().setf(ios::fixed);
2095     }
2096 
2097   const uword X_n_cols = X.n_cols;
2098 
2099   if(X_n_cols == 0)  { return true; }
2100 
2101   const uword N_dims = means.n_rows;
2102   const uword N_gaus = means.n_cols;
2103 
2104   const eT* mah_aux_mem = mah_aux.memptr();
2105 
2106   Mat<eT>    acc_means(N_dims, N_gaus, arma_zeros_indicator());
2107   Row<uword> acc_hefts(        N_gaus, arma_zeros_indicator());
2108   Row<uword> last_indx(        N_gaus, arma_zeros_indicator());
2109 
2110   Mat<eT> new_means = means;
2111   Mat<eT> old_means = means;
2112 
2113   running_mean_scalar<eT> rs_delta;
2114 
2115   #if defined(ARMA_USE_OPENMP)
2116     const umat boundaries = internal_gen_boundaries(X_n_cols);
2117     const uword n_threads = boundaries.n_cols;
2118 
2119     field< Mat<eT>    > t_acc_means(n_threads);
2120     field< Row<uword> > t_acc_hefts(n_threads);
2121     field< Row<uword> > t_last_indx(n_threads);
2122   #else
2123     const uword n_threads = 1;
2124   #endif
2125 
2126   if(verbose)  { get_cout_stream() << "gmm_full::learn(): k-means: n_threads: " << n_threads << '\n';  get_cout_stream().flush(); }
2127 
2128   for(uword iter=1; iter <= max_iter; ++iter)
2129     {
2130     #if defined(ARMA_USE_OPENMP)
2131       {
2132       for(uword t=0; t < n_threads; ++t)
2133         {
2134         t_acc_means(t).zeros(N_dims, N_gaus);
2135         t_acc_hefts(t).zeros(N_gaus);
2136         t_last_indx(t).zeros(N_gaus);
2137         }
2138 
2139       #pragma omp parallel for schedule(static)
2140       for(uword t=0; t < n_threads; ++t)
2141         {
2142         Mat<eT>& t_acc_means_t   = t_acc_means(t);
2143         uword*   t_acc_hefts_mem = t_acc_hefts(t).memptr();
2144         uword*   t_last_indx_mem = t_last_indx(t).memptr();
2145 
2146         const uword start_index = boundaries.at(0,t);
2147         const uword   end_index = boundaries.at(1,t);
2148 
2149         for(uword i=start_index; i <= end_index; ++i)
2150           {
2151           const eT* X_colptr = X.colptr(i);
2152 
2153           eT     min_dist = Datum<eT>::inf;
2154           uword  best_g   = 0;
2155 
2156           for(uword g=0; g<N_gaus; ++g)
2157             {
2158             const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
2159 
2160             if(dist < min_dist)  { min_dist = dist;  best_g = g; }
2161             }
2162 
2163           eT* t_acc_mean = t_acc_means_t.colptr(best_g);
2164 
2165           for(uword d=0; d<N_dims; ++d)  { t_acc_mean[d] += X_colptr[d]; }
2166 
2167           t_acc_hefts_mem[best_g]++;
2168           t_last_indx_mem[best_g] = i;
2169           }
2170         }
2171 
2172       // reduction
2173 
2174       acc_means = t_acc_means(0);
2175       acc_hefts = t_acc_hefts(0);
2176 
2177       for(uword t=1; t < n_threads; ++t)
2178         {
2179         acc_means += t_acc_means(t);
2180         acc_hefts += t_acc_hefts(t);
2181         }
2182 
2183       for(uword g=0; g < N_gaus;    ++g)
2184       for(uword t=0; t < n_threads; ++t)
2185         {
2186         if( t_acc_hefts(t)(g) >= 1 )  { last_indx(g) = t_last_indx(t)(g); }
2187         }
2188       }
2189     #else
2190       {
2191       uword* acc_hefts_mem = acc_hefts.memptr();
2192       uword* last_indx_mem = last_indx.memptr();
2193 
2194       for(uword i=0; i < X_n_cols; ++i)
2195         {
2196         const eT* X_colptr = X.colptr(i);
2197 
2198         eT     min_dist = Datum<eT>::inf;
2199         uword  best_g   = 0;
2200 
2201         for(uword g=0; g<N_gaus; ++g)
2202           {
2203           const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
2204 
2205           if(dist < min_dist)  { min_dist = dist;  best_g = g; }
2206           }
2207 
2208         eT* acc_mean = acc_means.colptr(best_g);
2209 
2210         for(uword d=0; d<N_dims; ++d)  { acc_mean[d] += X_colptr[d]; }
2211 
2212         acc_hefts_mem[best_g]++;
2213         last_indx_mem[best_g] = i;
2214         }
2215       }
2216     #endif
2217 
2218     // generate new means
2219 
2220     uword* acc_hefts_mem = acc_hefts.memptr();
2221 
2222     for(uword g=0; g < N_gaus; ++g)
2223       {
2224       const eT*   acc_mean = acc_means.colptr(g);
2225       const uword acc_heft = acc_hefts_mem[g];
2226 
2227       eT* new_mean = access::rw(new_means).colptr(g);
2228 
2229       for(uword d=0; d<N_dims; ++d)
2230         {
2231         new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
2232         }
2233       }
2234 
2235 
2236     // heuristics to resurrect dead means
2237 
2238     const uvec dead_gs = find(acc_hefts == uword(0));
2239 
2240     if(dead_gs.n_elem > 0)
2241       {
2242       if(verbose)  { get_cout_stream() << "gmm_full::learn(): k-means: recovering from dead means\n"; get_cout_stream().flush(); }
2243 
2244       uword* last_indx_mem = last_indx.memptr();
2245 
2246       const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
2247 
2248       if(live_gs.n_elem == 0)  { return false; }
2249 
2250       uword live_gs_count  = 0;
2251 
2252       for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
2253         {
2254         const uword dead_g_id = dead_gs(dead_gs_count);
2255 
2256         uword proposed_i = 0;
2257 
2258         if(live_gs_count < live_gs.n_elem)
2259           {
2260           const uword live_g_id = live_gs(live_gs_count);  ++live_gs_count;
2261 
2262           if(live_g_id == dead_g_id)  { return false; }
2263 
2264           // recover by using a sample from a known good mean
2265           proposed_i = last_indx_mem[live_g_id];
2266           }
2267         else
2268           {
2269           // recover by using a randomly seleced sample (last resort)
2270           proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
2271           }
2272 
2273         if(proposed_i >= X_n_cols)  { return false; }
2274 
2275         new_means.col(dead_g_id) = X.col(proposed_i);
2276         }
2277       }
2278 
2279     rs_delta.reset();
2280 
2281     for(uword g=0; g < N_gaus; ++g)
2282       {
2283       rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
2284       }
2285 
2286     if(verbose)
2287       {
2288       get_cout_stream() << "gmm_full::learn(): k-means: iteration: ";
2289       get_cout_stream().unsetf(ios::scientific);
2290       get_cout_stream().setf(ios::fixed);
2291       get_cout_stream().width(std::streamsize(4));
2292       get_cout_stream() << iter;
2293       get_cout_stream() << "   delta: ";
2294       get_cout_stream().unsetf(ios::fixed);
2295       //get_cout_stream().setf(ios::scientific);
2296       get_cout_stream() << rs_delta.mean() << '\n';
2297       get_cout_stream().flush();
2298       }
2299 
2300     arma::swap(old_means, new_means);
2301 
2302     if(rs_delta.mean() <= Datum<eT>::eps)  { break; }
2303     }
2304 
2305   access::rw(means) = old_means;
2306 
2307   if(means.is_finite() == false)  { return false; }
2308 
2309   return true;
2310   }
2311 
2312 
2313 
2314 //! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce
2315 template<typename eT>
2316 inline
2317 bool
em_iterate(const Mat<eT> & X,const uword max_iter,const eT var_floor,const bool verbose)2318 gmm_full<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
2319   {
2320   arma_extra_debug_sigprint();
2321 
2322   const uword N_dims = means.n_rows;
2323   const uword N_gaus = means.n_cols;
2324 
2325   if(verbose)
2326     {
2327     get_cout_stream().unsetf(ios::showbase);
2328     get_cout_stream().unsetf(ios::uppercase);
2329     get_cout_stream().unsetf(ios::showpos);
2330     get_cout_stream().unsetf(ios::scientific);
2331 
2332     get_cout_stream().setf(ios::right);
2333     get_cout_stream().setf(ios::fixed);
2334     }
2335 
2336   const umat boundaries = internal_gen_boundaries(X.n_cols);
2337 
2338   const uword n_threads = boundaries.n_cols;
2339 
2340   field<  Mat<eT> > t_acc_means(n_threads);
2341   field< Cube<eT> > t_acc_fcovs(n_threads);
2342 
2343   field< Col<eT> > t_acc_norm_lhoods(n_threads);
2344   field< Col<eT> > t_gaus_log_lhoods(n_threads);
2345 
2346   Col<eT>          t_progress_log_lhood(n_threads, arma_nozeros_indicator());
2347 
2348   for(uword t=0; t<n_threads; t++)
2349     {
2350     t_acc_means[t].set_size(N_dims, N_gaus);
2351     t_acc_fcovs[t].set_size(N_dims, N_dims, N_gaus);
2352 
2353     t_acc_norm_lhoods[t].set_size(N_gaus);
2354     t_gaus_log_lhoods[t].set_size(N_gaus);
2355     }
2356 
2357 
2358   if(verbose)
2359     {
2360     get_cout_stream() << "gmm_full::learn(): EM: n_threads: " << n_threads  << '\n';
2361     }
2362 
2363   eT old_avg_log_p = -Datum<eT>::inf;
2364 
2365   const bool calc_chol = false;
2366 
2367   for(uword iter=1; iter <= max_iter; ++iter)
2368     {
2369     init_constants(calc_chol);
2370 
2371     em_update_params(X, boundaries, t_acc_means, t_acc_fcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood, var_floor);
2372 
2373     em_fix_params(var_floor);
2374 
2375     const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
2376 
2377     if(verbose)
2378       {
2379       get_cout_stream() << "gmm_full::learn(): EM: iteration: ";
2380       get_cout_stream().unsetf(ios::scientific);
2381       get_cout_stream().setf(ios::fixed);
2382       get_cout_stream().width(std::streamsize(4));
2383       get_cout_stream() << iter;
2384       get_cout_stream() << "   avg_log_p: ";
2385       get_cout_stream().unsetf(ios::fixed);
2386       //get_cout_stream().setf(ios::scientific);
2387       get_cout_stream() << new_avg_log_p << '\n';
2388       get_cout_stream().flush();
2389       }
2390 
2391     if(arma_isfinite(new_avg_log_p) == false)  { return false; }
2392 
2393     if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps)  { break; }
2394 
2395 
2396     old_avg_log_p = new_avg_log_p;
2397     }
2398 
2399 
2400   for(uword g=0; g < N_gaus; ++g)
2401     {
2402     const Mat<eT>& fcov = fcovs.slice(g);
2403 
2404     if(any(vectorise(fcov.diag()) <= eT(0)))  { return false; }
2405     }
2406 
2407   if(means.is_finite() == false)  { return false; }
2408   if(fcovs.is_finite() == false)  { return false; }
2409   if(hefts.is_finite() == false)  { return false; }
2410 
2411   return true;
2412   }
2413 
2414 
2415 
2416 
2417 template<typename eT>
2418 inline
2419 void
em_update_params(const Mat<eT> & X,const umat & boundaries,field<Mat<eT>> & t_acc_means,field<Cube<eT>> & t_acc_fcovs,field<Col<eT>> & t_acc_norm_lhoods,field<Col<eT>> & t_gaus_log_lhoods,Col<eT> & t_progress_log_lhood,const eT var_floor)2420 gmm_full<eT>::em_update_params
2421   (
2422   const Mat<eT>&           X,
2423   const umat&              boundaries,
2424         field<  Mat<eT> >& t_acc_means,
2425         field< Cube<eT> >& t_acc_fcovs,
2426         field<  Col<eT> >& t_acc_norm_lhoods,
2427         field<  Col<eT> >& t_gaus_log_lhoods,
2428         Col<eT>&           t_progress_log_lhood,
2429   const eT                 var_floor
2430   )
2431   {
2432   arma_extra_debug_sigprint();
2433 
2434   const uword n_threads = boundaries.n_cols;
2435 
2436 
2437   // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
2438 
2439   #if defined(ARMA_USE_OPENMP)
2440     {
2441     #pragma omp parallel for schedule(static)
2442     for(uword t=0; t<n_threads; t++)
2443       {
2444        Mat<eT>& acc_means          = t_acc_means[t];
2445       Cube<eT>& acc_fcovs          = t_acc_fcovs[t];
2446        Col<eT>& acc_norm_lhoods    = t_acc_norm_lhoods[t];
2447        Col<eT>& gaus_log_lhoods    = t_gaus_log_lhoods[t];
2448        eT&      progress_log_lhood = t_progress_log_lhood[t];
2449 
2450       em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_fcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
2451       }
2452     }
2453   #else
2454     {
2455     em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_fcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
2456     }
2457   #endif
2458 
2459   const uword N_dims = means.n_rows;
2460   const uword N_gaus = means.n_cols;
2461 
2462    Mat<eT>& final_acc_means = t_acc_means[0];
2463   Cube<eT>& final_acc_fcovs = t_acc_fcovs[0];
2464 
2465   Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
2466 
2467 
2468   // the "reduce" operation, which combines the partial accumulators produced by the separate threads
2469 
2470   for(uword t=1; t<n_threads; t++)
2471     {
2472     final_acc_means += t_acc_means[t];
2473     final_acc_fcovs += t_acc_fcovs[t];
2474 
2475     final_acc_norm_lhoods += t_acc_norm_lhoods[t];
2476     }
2477 
2478 
2479   eT* hefts_mem = access::rw(hefts).memptr();
2480 
2481   Mat<eT> mean_outer(N_dims, N_dims, arma_nozeros_indicator());
2482 
2483 
2484   //// update each component without sanity checking
2485   //for(uword g=0; g < N_gaus; ++g)
2486   //  {
2487   //  const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
2488   //
2489   //    hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
2490   //
2491   //    eT*     mean_mem = access::rw(means).colptr(g);
2492   //    eT* acc_mean_mem = final_acc_means.colptr(g);
2493   //
2494   //    for(uword d=0; d < N_dims; ++d)
2495   //      {
2496   //      mean_mem[d] = acc_mean_mem[d] / acc_norm_lhood;
2497   //      }
2498   //
2499   //    const Col<eT> mean(mean_mem, N_dims, false, true);
2500   //
2501   //    mean_outer = mean * mean.t();
2502   //
2503   //    Mat<eT>&     fcov = access::rw(fcovs).slice(g);
2504   //    Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
2505   //
2506   //    fcov = acc_fcov / acc_norm_lhood - mean_outer;
2507   //    }
2508 
2509 
2510   // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them
2511   for(uword g=0; g < N_gaus; ++g)
2512     {
2513     const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
2514 
2515     if(arma_isfinite(acc_norm_lhood) == false)  { continue; }
2516 
2517     eT* acc_mean_mem = final_acc_means.colptr(g);
2518 
2519     for(uword d=0; d < N_dims; ++d)
2520       {
2521       acc_mean_mem[d] /= acc_norm_lhood;
2522       }
2523 
2524     const Col<eT> new_mean(acc_mean_mem, N_dims, false, true);
2525 
2526     mean_outer = new_mean * new_mean.t();
2527 
2528     Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
2529 
2530     acc_fcov /= acc_norm_lhood;
2531     acc_fcov -= mean_outer;
2532 
2533     for(uword d=0; d < N_dims; ++d)
2534       {
2535       eT& val = acc_fcov.at(d,d);
2536 
2537       if(val < var_floor)  { val = var_floor; }
2538       }
2539 
2540     if(acc_fcov.is_finite() == false)  { continue; }
2541 
2542     eT log_det_val  = eT(0);
2543     eT log_det_sign = eT(0);
2544 
2545     const bool log_det_status = log_det(log_det_val, log_det_sign, acc_fcov);
2546 
2547     const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
2548 
2549     const bool inv_ok = (log_det_ok) ? bool(auxlib::inv_sympd(mean_outer, acc_fcov)) : bool(false);  // mean_outer is used as a junk matrix
2550 
2551     if(log_det_ok && inv_ok)
2552       {
2553       hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
2554 
2555       eT* mean_mem = access::rw(means).colptr(g);
2556 
2557       for(uword d=0; d < N_dims; ++d)
2558         {
2559         mean_mem[d] = acc_mean_mem[d];
2560         }
2561 
2562       Mat<eT>& fcov = access::rw(fcovs).slice(g);
2563 
2564       fcov = acc_fcov;
2565       }
2566     }
2567   }
2568 
2569 
2570 
2571 template<typename eT>
2572 inline
2573 void
em_generate_acc(const Mat<eT> & X,const uword start_index,const uword end_index,Mat<eT> & acc_means,Cube<eT> & acc_fcovs,Col<eT> & acc_norm_lhoods,Col<eT> & gaus_log_lhoods,eT & progress_log_lhood) const2574 gmm_full<eT>::em_generate_acc
2575   (
2576   const  Mat<eT>& X,
2577   const  uword    start_index,
2578   const  uword      end_index,
2579          Mat<eT>& acc_means,
2580         Cube<eT>& acc_fcovs,
2581          Col<eT>& acc_norm_lhoods,
2582          Col<eT>& gaus_log_lhoods,
2583          eT&      progress_log_lhood
2584   )
2585   const
2586   {
2587   arma_extra_debug_sigprint();
2588 
2589   progress_log_lhood = eT(0);
2590 
2591   acc_means.zeros();
2592   acc_fcovs.zeros();
2593 
2594   acc_norm_lhoods.zeros();
2595   gaus_log_lhoods.zeros();
2596 
2597   const uword N_dims = means.n_rows;
2598   const uword N_gaus = means.n_cols;
2599 
2600   const eT* log_hefts_mem       = log_hefts.memptr();
2601         eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
2602 
2603 
2604   for(uword i=start_index; i <= end_index; i++)
2605     {
2606     const eT* x = X.colptr(i);
2607 
2608     for(uword g=0; g < N_gaus; ++g)
2609       {
2610       gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
2611       }
2612 
2613     eT log_lhood_sum = gaus_log_lhoods_mem[0];
2614 
2615     for(uword g=1; g < N_gaus; ++g)
2616       {
2617       log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
2618       }
2619 
2620     progress_log_lhood += log_lhood_sum;
2621 
2622     for(uword g=0; g < N_gaus; ++g)
2623       {
2624       const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
2625 
2626       acc_norm_lhoods[g] += norm_lhood;
2627 
2628       eT* acc_mean_mem = acc_means.colptr(g);
2629 
2630       for(uword d=0; d < N_dims; ++d)
2631         {
2632         acc_mean_mem[d] += x[d] * norm_lhood;
2633         }
2634 
2635       Mat<eT>& acc_fcov = access::rw(acc_fcovs).slice(g);
2636 
2637       // specialised version of acc_fcov += norm_lhood * (xx * xx.t());
2638 
2639       for(uword d=0; d < N_dims; ++d)
2640         {
2641         const uword dp1 = d+1;
2642 
2643         const eT xd = x[d];
2644 
2645         eT* acc_fcov_col_d = acc_fcov.colptr(d) + d;
2646         eT* acc_fcov_row_d = &(acc_fcov.at(d,dp1));
2647 
2648         (*acc_fcov_col_d) += norm_lhood * (xd * xd);  acc_fcov_col_d++;
2649 
2650         for(uword e=dp1; e < N_dims; ++e)
2651           {
2652           const eT val = norm_lhood * (xd * x[e]);
2653 
2654           (*acc_fcov_col_d) += val;  acc_fcov_col_d++;
2655           (*acc_fcov_row_d) += val;  acc_fcov_row_d += N_dims;
2656           }
2657         }
2658       }
2659     }
2660 
2661   progress_log_lhood /= eT((end_index - start_index) + 1);
2662   }
2663 
2664 
2665 
2666 template<typename eT>
2667 inline
2668 void
em_fix_params(const eT var_floor)2669 gmm_full<eT>::em_fix_params(const eT var_floor)
2670   {
2671   arma_extra_debug_sigprint();
2672 
2673   const uword N_dims = means.n_rows;
2674   const uword N_gaus = means.n_cols;
2675 
2676   const eT var_ceiling = std::numeric_limits<eT>::max();
2677 
2678   for(uword g=0; g < N_gaus; ++g)
2679     {
2680     Mat<eT>& fcov = access::rw(fcovs).slice(g);
2681 
2682     for(uword d=0; d < N_dims; ++d)
2683       {
2684       eT& var_val = fcov.at(d,d);
2685 
2686            if(var_val < var_floor  )  { var_val = var_floor;   }
2687       else if(var_val > var_ceiling)  { var_val = var_ceiling; }
2688       else if(arma_isnan(var_val)  )  { var_val = eT(1);       }
2689       }
2690     }
2691 
2692 
2693   eT* hefts_mem = access::rw(hefts).memptr();
2694 
2695   for(uword g1=0; g1 < N_gaus; ++g1)
2696     {
2697     if(hefts_mem[g1] > eT(0))
2698       {
2699       const eT* means_colptr_g1 = means.colptr(g1);
2700 
2701       for(uword g2=(g1+1); g2 < N_gaus; ++g2)
2702         {
2703         if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
2704           {
2705           const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
2706 
2707           if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
2708           }
2709         }
2710       }
2711     }
2712 
2713   const eT heft_floor   = std::numeric_limits<eT>::min();
2714   const eT heft_initial = eT(1) / eT(N_gaus);
2715 
2716   for(uword i=0; i < N_gaus; ++i)
2717     {
2718     eT& heft_val = hefts_mem[i];
2719 
2720          if(heft_val < heft_floor)  { heft_val = heft_floor;   }
2721     else if(heft_val > eT(1)     )  { heft_val = eT(1);        }
2722     else if(arma_isnan(heft_val) )  { heft_val = heft_initial; }
2723     }
2724 
2725   const eT heft_sum = accu(hefts);
2726 
2727   if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps)))  { access::rw(hefts) /= heft_sum; }
2728   }
2729 
2730 
2731 
2732 } // namespace gmm_priv
2733 
2734 
2735 //! @}
2736