// SPDX-License-Identifier: Apache-2.0
//
// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
// Copyright 2008-2016 National ICT Australia (NICTA)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ------------------------------------------------------------------------

18
//! \addtogroup gmm_diag
//! @{


namespace gmm_priv
{

27 template<typename eT>
28 inline
~gmm_diag()29 gmm_diag<eT>::~gmm_diag()
30 {
31 arma_extra_debug_sigprint_this(this);
32
33 arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
34 }
35
36
37
38 template<typename eT>
39 inline
gmm_diag()40 gmm_diag<eT>::gmm_diag()
41 {
42 arma_extra_debug_sigprint_this(this);
43 }
44
45
46
47 template<typename eT>
48 inline
gmm_diag(const gmm_diag<eT> & x)49 gmm_diag<eT>::gmm_diag(const gmm_diag<eT>& x)
50 {
51 arma_extra_debug_sigprint_this(this);
52
53 init(x);
54 }
55
56
57
58 template<typename eT>
59 inline
60 gmm_diag<eT>&
operator =(const gmm_diag<eT> & x)61 gmm_diag<eT>::operator=(const gmm_diag<eT>& x)
62 {
63 arma_extra_debug_sigprint();
64
65 init(x);
66
67 return *this;
68 }
69
70
71
72 template<typename eT>
73 inline
gmm_diag(const gmm_full<eT> & x)74 gmm_diag<eT>::gmm_diag(const gmm_full<eT>& x)
75 {
76 arma_extra_debug_sigprint_this(this);
77
78 init(x);
79 }
80
81
82
83 template<typename eT>
84 inline
85 gmm_diag<eT>&
operator =(const gmm_full<eT> & x)86 gmm_diag<eT>::operator=(const gmm_full<eT>& x)
87 {
88 arma_extra_debug_sigprint();
89
90 init(x);
91
92 return *this;
93 }
94
95
96
97 template<typename eT>
98 inline
gmm_diag(const uword in_n_dims,const uword in_n_gaus)99 gmm_diag<eT>::gmm_diag(const uword in_n_dims, const uword in_n_gaus)
100 {
101 arma_extra_debug_sigprint_this(this);
102
103 init(in_n_dims, in_n_gaus);
104 }
105
106
107
108 template<typename eT>
109 inline
110 void
reset()111 gmm_diag<eT>::reset()
112 {
113 arma_extra_debug_sigprint();
114
115 init(0, 0);
116 }
117
118
119
120 template<typename eT>
121 inline
122 void
reset(const uword in_n_dims,const uword in_n_gaus)123 gmm_diag<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
124 {
125 arma_extra_debug_sigprint();
126
127 init(in_n_dims, in_n_gaus);
128 }
129
130
131
132 template<typename eT>
133 template<typename T1, typename T2, typename T3>
134 inline
135 void
set_params(const Base<eT,T1> & in_means_expr,const Base<eT,T2> & in_dcovs_expr,const Base<eT,T3> & in_hefts_expr)136 gmm_diag<eT>::set_params(const Base<eT,T1>& in_means_expr, const Base<eT,T2>& in_dcovs_expr, const Base<eT,T3>& in_hefts_expr)
137 {
138 arma_extra_debug_sigprint();
139
140 const unwrap<T1> tmp1(in_means_expr.get_ref());
141 const unwrap<T2> tmp2(in_dcovs_expr.get_ref());
142 const unwrap<T3> tmp3(in_hefts_expr.get_ref());
143
144 const Mat<eT>& in_means = tmp1.M;
145 const Mat<eT>& in_dcovs = tmp2.M;
146 const Mat<eT>& in_hefts = tmp3.M;
147
148 arma_debug_check
149 (
150 (arma::size(in_means) != arma::size(in_dcovs)) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
151 "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes"
152 );
153
154 arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_params(): given means have non-finite values" );
155 arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_params(): given dcovs have non-finite values" );
156 arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_params(): given hefts have non-finite values" );
157
158 arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" );
159 arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_params(): given hefts have negative values" );
160
161 const eT s = accu(in_hefts);
162
163 arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_params(): sum of given hefts is not 1" );
164
165 access::rw(means) = in_means;
166 access::rw(dcovs) = in_dcovs;
167 access::rw(hefts) = in_hefts;
168
169 init_constants();
170 }
171
172
173
174 template<typename eT>
175 template<typename T1>
176 inline
177 void
set_means(const Base<eT,T1> & in_means_expr)178 gmm_diag<eT>::set_means(const Base<eT,T1>& in_means_expr)
179 {
180 arma_extra_debug_sigprint();
181
182 const unwrap<T1> tmp(in_means_expr.get_ref());
183
184 const Mat<eT>& in_means = tmp.M;
185
186 arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_diag::set_means(): given means have incompatible size" );
187 arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_means(): given means have non-finite values" );
188
189 access::rw(means) = in_means;
190 }
191
192
193
194 template<typename eT>
195 template<typename T1>
196 inline
197 void
set_dcovs(const Base<eT,T1> & in_dcovs_expr)198 gmm_diag<eT>::set_dcovs(const Base<eT,T1>& in_dcovs_expr)
199 {
200 arma_extra_debug_sigprint();
201
202 const unwrap<T1> tmp(in_dcovs_expr.get_ref());
203
204 const Mat<eT>& in_dcovs = tmp.M;
205
206 arma_debug_check( (arma::size(in_dcovs) != arma::size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size" );
207 arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_dcovs(): given dcovs have non-finite values" );
208 arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_dcovs(): given dcovs have negative or zero values" );
209
210 access::rw(dcovs) = in_dcovs;
211
212 init_constants();
213 }
214
215
216
217 template<typename eT>
218 template<typename T1>
219 inline
220 void
set_hefts(const Base<eT,T1> & in_hefts_expr)221 gmm_diag<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
222 {
223 arma_extra_debug_sigprint();
224
225 const unwrap<T1> tmp(in_hefts_expr.get_ref());
226
227 const Mat<eT>& in_hefts = tmp.M;
228
229 arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" );
230 arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_hefts(): given hefts have non-finite values" );
231 arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_hefts(): given hefts have negative values" );
232
233 const eT s = accu(in_hefts);
234
235 arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_diag::set_hefts(): sum of given hefts is not 1" );
236
237 // make sure all hefts are positive and non-zero
238
239 const eT* in_hefts_mem = in_hefts.memptr();
240 eT* hefts_mem = access::rw(hefts).memptr();
241
242 for(uword i=0; i < hefts.n_elem; ++i)
243 {
244 hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
245 }
246
247 access::rw(hefts) /= accu(hefts);
248
249 log_hefts = log(hefts);
250 }
251
252
253
254 template<typename eT>
255 inline
256 uword
n_dims() const257 gmm_diag<eT>::n_dims() const
258 {
259 return means.n_rows;
260 }
261
262
263
264 template<typename eT>
265 inline
266 uword
n_gaus() const267 gmm_diag<eT>::n_gaus() const
268 {
269 return means.n_cols;
270 }
271
272
273
274 template<typename eT>
275 inline
276 bool
load(const std::string name)277 gmm_diag<eT>::load(const std::string name)
278 {
279 arma_extra_debug_sigprint();
280
281 Cube<eT> Q;
282
283 bool status = Q.load(name, arma_binary);
284
285 if( (status == false) || (Q.n_slices != 2) )
286 {
287 reset();
288 arma_debug_warn_level(3, "gmm_diag::load(): problem with loading or incompatible format");
289 return false;
290 }
291
292 if( (Q.n_rows < 2) || (Q.n_cols < 1) )
293 {
294 reset();
295 return true;
296 }
297
298 access::rw(hefts) = Q.slice(0).row(0);
299 access::rw(means) = Q.slice(0).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
300 access::rw(dcovs) = Q.slice(1).submat(1, 0, Q.n_rows-1, Q.n_cols-1);
301
302 init_constants();
303
304 return true;
305 }
306
307
308
309 template<typename eT>
310 inline
311 bool
save(const std::string name) const312 gmm_diag<eT>::save(const std::string name) const
313 {
314 arma_extra_debug_sigprint();
315
316 Cube<eT> Q(means.n_rows + 1, means.n_cols, 2, arma_nozeros_indicator());
317
318 if(Q.n_elem > 0)
319 {
320 Q.slice(0).row(0) = hefts;
321 Q.slice(1).row(0).zeros(); // reserved for future use
322
323 Q.slice(0).submat(1, 0, arma::size(means)) = means;
324 Q.slice(1).submat(1, 0, arma::size(dcovs)) = dcovs;
325 }
326
327 const bool status = Q.save(name, arma_binary);
328
329 return status;
330 }
331
332
333
334 template<typename eT>
335 inline
336 Col<eT>
generate() const337 gmm_diag<eT>::generate() const
338 {
339 arma_extra_debug_sigprint();
340
341 const uword N_dims = means.n_rows;
342 const uword N_gaus = means.n_cols;
343
344 Col<eT> out( ((N_gaus > 0) ? N_dims : uword(0)), fill::randn );
345
346 if(N_gaus > 0)
347 {
348 const double val = randu<double>();
349
350 double csum = double(0);
351 uword gaus_id = 0;
352
353 for(uword j=0; j < N_gaus; ++j)
354 {
355 csum += hefts[j];
356
357 if(val <= csum) { gaus_id = j; break; }
358 }
359
360 out %= sqrt(dcovs.col(gaus_id));
361 out += means.col(gaus_id);
362 }
363
364 return out;
365 }
366
367
368
369 template<typename eT>
370 inline
371 Mat<eT>
generate(const uword N_vec) const372 gmm_diag<eT>::generate(const uword N_vec) const
373 {
374 arma_extra_debug_sigprint();
375
376 const uword N_dims = means.n_rows;
377 const uword N_gaus = means.n_cols;
378
379 Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn );
380
381 if(N_gaus > 0)
382 {
383 const eT* hefts_mem = hefts.memptr();
384
385 const Mat<eT> sqrt_dcovs = sqrt(dcovs);
386
387 for(uword i=0; i < N_vec; ++i)
388 {
389 const double val = randu<double>();
390
391 double csum = double(0);
392 uword gaus_id = 0;
393
394 for(uword j=0; j < N_gaus; ++j)
395 {
396 csum += hefts_mem[j];
397
398 if(val <= csum) { gaus_id = j; break; }
399 }
400
401 subview_col<eT> out_col = out.col(i);
402
403 out_col %= sqrt_dcovs.col(gaus_id);
404 out_col += means.col(gaus_id);
405 }
406 }
407
408 return out;
409 }
410
411
412
413 template<typename eT>
414 template<typename T1>
415 inline
416 eT
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const417 gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
418 {
419 arma_extra_debug_sigprint();
420 arma_ignore(junk1);
421 arma_ignore(junk2);
422
423 const quasi_unwrap<T1> tmp(expr);
424
425 arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
426
427 return internal_scalar_log_p( tmp.M.memptr() );
428 }
429
430
431
432 template<typename eT>
433 template<typename T1>
434 inline
435 eT
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const436 gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
437 {
438 arma_extra_debug_sigprint();
439 arma_ignore(junk2);
440
441 const quasi_unwrap<T1> tmp(expr);
442
443 arma_debug_check( (tmp.M.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
444
445 arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
446
447 return internal_scalar_log_p( tmp.M.memptr(), gaus_id );
448 }
449
450
451
452 template<typename eT>
453 template<typename T1>
454 inline
455 Row<eT>
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const456 gmm_diag<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
457 {
458 arma_extra_debug_sigprint();
459 arma_ignore(junk1);
460 arma_ignore(junk2);
461
462 const quasi_unwrap<T1> tmp(expr);
463
464 const Mat<eT>& X = tmp.M;
465
466 return internal_vec_log_p(X);
467 }
468
469
470
471 template<typename eT>
472 template<typename T1>
473 inline
474 Row<eT>
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const475 gmm_diag<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
476 {
477 arma_extra_debug_sigprint();
478 arma_ignore(junk2);
479
480 const quasi_unwrap<T1> tmp(expr);
481
482 const Mat<eT>& X = tmp.M;
483
484 return internal_vec_log_p(X, gaus_id);
485 }
486
487
488
489 template<typename eT>
490 template<typename T1>
491 inline
492 eT
sum_log_p(const Base<eT,T1> & expr) const493 gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr) const
494 {
495 arma_extra_debug_sigprint();
496
497 const quasi_unwrap<T1> tmp(expr.get_ref());
498
499 const Mat<eT>& X = tmp.M;
500
501 return internal_sum_log_p(X);
502 }
503
504
505
506 template<typename eT>
507 template<typename T1>
508 inline
509 eT
sum_log_p(const Base<eT,T1> & expr,const uword gaus_id) const510 gmm_diag<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
511 {
512 arma_extra_debug_sigprint();
513
514 const quasi_unwrap<T1> tmp(expr.get_ref());
515
516 const Mat<eT>& X = tmp.M;
517
518 return internal_sum_log_p(X, gaus_id);
519 }
520
521
522
523 template<typename eT>
524 template<typename T1>
525 inline
526 eT
avg_log_p(const Base<eT,T1> & expr) const527 gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr) const
528 {
529 arma_extra_debug_sigprint();
530
531 const quasi_unwrap<T1> tmp(expr.get_ref());
532
533 const Mat<eT>& X = tmp.M;
534
535 return internal_avg_log_p(X);
536 }
537
538
539
540 template<typename eT>
541 template<typename T1>
542 inline
543 eT
avg_log_p(const Base<eT,T1> & expr,const uword gaus_id) const544 gmm_diag<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
545 {
546 arma_extra_debug_sigprint();
547
548 const quasi_unwrap<T1> tmp(expr.get_ref());
549
550 const Mat<eT>& X = tmp.M;
551
552 return internal_avg_log_p(X, gaus_id);
553 }
554
555
556
557 template<typename eT>
558 template<typename T1>
559 inline
560 uword
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk) const561 gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
562 {
563 arma_extra_debug_sigprint();
564 arma_ignore(junk);
565
566 const quasi_unwrap<T1> tmp(expr);
567
568 const Mat<eT>& X = tmp.M;
569
570 return internal_scalar_assign(X, dist);
571 }
572
573
574
575 template<typename eT>
576 template<typename T1>
577 inline
578 urowvec
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk) const579 gmm_diag<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
580 {
581 arma_extra_debug_sigprint();
582 arma_ignore(junk);
583
584 urowvec out;
585
586 const quasi_unwrap<T1> tmp(expr);
587
588 const Mat<eT>& X = tmp.M;
589
590 internal_vec_assign(out, X, dist);
591
592 return out;
593 }
594
595
596
597 template<typename eT>
598 template<typename T1>
599 inline
600 urowvec
raw_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const601 gmm_diag<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
602 {
603 arma_extra_debug_sigprint();
604
605 const unwrap<T1> tmp(expr.get_ref());
606 const Mat<eT>& X = tmp.M;
607
608 arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::raw_hist(): incompatible dimensions" );
609
610 arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::raw_hist(): unsupported distance mode" );
611
612 urowvec hist;
613
614 internal_raw_hist(hist, X, dist_mode);
615
616 return hist;
617 }
618
619
620
621 template<typename eT>
622 template<typename T1>
623 inline
624 Row<eT>
norm_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const625 gmm_diag<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
626 {
627 arma_extra_debug_sigprint();
628
629 const unwrap<T1> tmp(expr.get_ref());
630 const Mat<eT>& X = tmp.M;
631
632 arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::norm_hist(): incompatible dimensions" );
633
634 arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_diag::norm_hist(): unsupported distance mode" );
635
636 urowvec hist;
637
638 internal_raw_hist(hist, X, dist_mode);
639
640 const uword hist_n_elem = hist.n_elem;
641 const uword* hist_mem = hist.memptr();
642
643 eT acc = eT(0);
644 for(uword i=0; i<hist_n_elem; ++i) { acc += eT(hist_mem[i]); }
645
646 if(acc == eT(0)) { acc = eT(1); }
647
648 Row<eT> out(hist_n_elem, arma_nozeros_indicator());
649
650 eT* out_mem = out.memptr();
651
652 for(uword i=0; i<hist_n_elem; ++i) { out_mem[i] = eT(hist_mem[i]) / acc; }
653
654 return out;
655 }
656
657
658
659 template<typename eT>
660 template<typename T1>
661 inline
662 bool
learn(const Base<eT,T1> & data,const uword N_gaus,const gmm_dist_mode & dist_mode,const gmm_seed_mode & seed_mode,const uword km_iter,const uword em_iter,const eT var_floor,const bool print_mode)663 gmm_diag<eT>::learn
664 (
665 const Base<eT,T1>& data,
666 const uword N_gaus,
667 const gmm_dist_mode& dist_mode,
668 const gmm_seed_mode& seed_mode,
669 const uword km_iter,
670 const uword em_iter,
671 const eT var_floor,
672 const bool print_mode
673 )
674 {
675 arma_extra_debug_sigprint();
676
677 const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
678
679 const bool seed_mode_ok = \
680 (seed_mode == keep_existing)
681 || (seed_mode == static_subset)
682 || (seed_mode == static_spread)
683 || (seed_mode == random_subset)
684 || (seed_mode == random_spread);
685
686 arma_debug_check( (dist_mode_ok == false), "gmm_diag::learn(): dist_mode must be eucl_dist or maha_dist" );
687 arma_debug_check( (seed_mode_ok == false), "gmm_diag::learn(): unknown seed_mode" );
688 arma_debug_check( (var_floor < eT(0) ), "gmm_diag::learn(): variance floor is negative" );
689
690 const unwrap<T1> tmp_X(data.get_ref());
691 const Mat<eT>& X = tmp_X.M;
692
693 if(X.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix is empty" ); return false; }
694 if(X.is_finite() == false) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix has non-finite values"); return false; }
695
696 if(N_gaus == 0) { reset(); return true; }
697
698 if(dist_mode == maha_dist)
699 {
700 mah_aux = var(X,1,1);
701
702 const uword mah_aux_n_elem = mah_aux.n_elem;
703 eT* mah_aux_mem = mah_aux.memptr();
704
705 for(uword i=0; i < mah_aux_n_elem; ++i)
706 {
707 const eT val = mah_aux_mem[i];
708
709 mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
710 }
711 }
712
713
714 // copy current model, in case of failure by k-means and/or EM
715
716 const gmm_diag<eT> orig = (*this);
717
718
719 // initial means
720
721 if(seed_mode == keep_existing)
722 {
723 if(means.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): no existing means" ); return false; }
724 if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "gmm_diag::learn(): dimensionality mismatch"); return false; }
725
726 // TODO: also check for number of vectors?
727 }
728 else
729 {
730 if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; }
731
732 reset(X.n_rows, N_gaus);
733
734 if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial means\n"; get_cout_stream().flush(); }
735
736 if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); }
737 else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); }
738 }
739
740
741 // k-means
742
743 if(km_iter > 0)
744 {
745 const arma_ostream_state stream_state(get_cout_stream());
746
747 bool status = false;
748
749 if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
750 else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode, "gmm_diag::learn(): k-means"); }
751
752 stream_state.restore(get_cout_stream());
753
754 if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
755 }
756
757
758 // initial dcovs
759
760 const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
761
762 if(seed_mode != keep_existing)
763 {
764 if(print_mode) { get_cout_stream() << "gmm_diag::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
765
766 if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); }
767 else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); }
768 }
769
770
771 // EM algorithm
772
773 if(em_iter > 0)
774 {
775 const arma_ostream_state stream_state(get_cout_stream());
776
777 const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
778
779 stream_state.restore(get_cout_stream());
780
781 if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): EM algorithm failed"); init(orig); return false; }
782 }
783
784 mah_aux.reset();
785
786 init_constants();
787
788 return true;
789 }
790
791
792
793 template<typename eT>
794 template<typename T1>
795 inline
796 bool
kmeans_wrapper(Mat<eT> & user_means,const Base<eT,T1> & data,const uword N_gaus,const gmm_seed_mode & seed_mode,const uword km_iter,const bool print_mode)797 gmm_diag<eT>::kmeans_wrapper
798 (
799 Mat<eT>& user_means,
800 const Base<eT,T1>& data,
801 const uword N_gaus,
802 const gmm_seed_mode& seed_mode,
803 const uword km_iter,
804 const bool print_mode
805 )
806 {
807 arma_extra_debug_sigprint();
808
809 const bool seed_mode_ok = \
810 (seed_mode == keep_existing)
811 || (seed_mode == static_subset)
812 || (seed_mode == static_spread)
813 || (seed_mode == random_subset)
814 || (seed_mode == random_spread);
815
816 arma_debug_check( (seed_mode_ok == false), "kmeans(): unknown seed_mode" );
817
818 const unwrap<T1> tmp_X(data.get_ref());
819 const Mat<eT>& X = tmp_X.M;
820
821 if(X.is_empty() ) { arma_debug_warn_level(3, "kmeans(): given matrix is empty" ); return false; }
822 if(X.is_finite() == false) { arma_debug_warn_level(3, "kmeans(): given matrix has non-finite values"); return false; }
823
824 if(N_gaus == 0) { reset(); return true; }
825
826
827 // initial means
828
829 if(seed_mode == keep_existing)
830 {
831 access::rw(means) = user_means;
832
833 if(means.is_empty() ) { arma_debug_warn_level(3, "kmeans(): no existing means" ); return false; }
834 if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "kmeans(): dimensionality mismatch"); return false; }
835
836 // TODO: also check for number of vectors?
837 }
838 else
839 {
840 if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "kmeans(): number of vectors is less than number of means"); return false; }
841
842 access::rw(means).zeros(X.n_rows, N_gaus);
843
844 if(print_mode) { get_cout_stream() << "kmeans(): generating initial means\n"; }
845
846 generate_initial_means<1>(X, seed_mode);
847 }
848
849
850 // k-means
851
852 if(km_iter > 0)
853 {
854 const arma_ostream_state stream_state(get_cout_stream());
855
856 bool status = false;
857
858 status = km_iterate<1>(X, km_iter, print_mode, "kmeans()");
859
860 stream_state.restore(get_cout_stream());
861
862 if(status == false) { arma_debug_warn_level(3, "kmeans(): clustering failed; not enough data, or too many means requested"); return false; }
863 }
864
865 return true;
866 }
867
868
869
870 //
871 //
872 //
873
874
875
876 template<typename eT>
877 inline
878 void
init(const gmm_diag<eT> & x)879 gmm_diag<eT>::init(const gmm_diag<eT>& x)
880 {
881 arma_extra_debug_sigprint();
882
883 gmm_diag<eT>& t = *this;
884
885 if(&t != &x)
886 {
887 access::rw(t.means) = x.means;
888 access::rw(t.dcovs) = x.dcovs;
889 access::rw(t.hefts) = x.hefts;
890
891 init_constants();
892 }
893 }
894
895
896
897 template<typename eT>
898 inline
899 void
init(const gmm_full<eT> & x)900 gmm_diag<eT>::init(const gmm_full<eT>& x)
901 {
902 arma_extra_debug_sigprint();
903
904 access::rw(hefts) = x.hefts;
905 access::rw(means) = x.means;
906
907 const uword N_dims = x.means.n_rows;
908 const uword N_gaus = x.means.n_cols;
909
910 access::rw(dcovs).zeros(N_dims,N_gaus);
911
912 for(uword g=0; g < N_gaus; ++g)
913 {
914 const Mat<eT>& fcov = x.fcovs.slice(g);
915
916 eT* dcov_mem = access::rw(dcovs).colptr(g);
917
918 for(uword d=0; d < N_dims; ++d)
919 {
920 dcov_mem[d] = fcov.at(d,d);
921 }
922 }
923
924 init_constants();
925 }
926
927
928
929 template<typename eT>
930 inline
931 void
init(const uword in_n_dims,const uword in_n_gaus)932 gmm_diag<eT>::init(const uword in_n_dims, const uword in_n_gaus)
933 {
934 arma_extra_debug_sigprint();
935
936 access::rw(means).zeros(in_n_dims, in_n_gaus);
937
938 access::rw(dcovs).ones(in_n_dims, in_n_gaus);
939
940 access::rw(hefts).set_size(in_n_gaus);
941
942 access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
943
944 init_constants();
945 }
946
947
948
949 template<typename eT>
950 inline
951 void
init_constants()952 gmm_diag<eT>::init_constants()
953 {
954 arma_extra_debug_sigprint();
955
956 const uword N_dims = means.n_rows;
957 const uword N_gaus = means.n_cols;
958
959 //
960
961 inv_dcovs.copy_size(dcovs);
962
963 const eT* dcovs_mem = dcovs.memptr();
964 eT* inv_dcovs_mem = inv_dcovs.memptr();
965
966 const uword dcovs_n_elem = dcovs.n_elem;
967
968 for(uword i=0; i < dcovs_n_elem; ++i)
969 {
970 inv_dcovs_mem[i] = eT(1) / (std::max)( dcovs_mem[i], std::numeric_limits<eT>::min() );
971 }
972
973 //
974
975 const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
976
977 log_det_etc.set_size(N_gaus);
978
979 for(uword g=0; g < N_gaus; ++g)
980 {
981 const eT* dcovs_colmem = dcovs.colptr(g);
982
983 eT log_det_val = eT(0);
984
985 for(uword d=0; d < N_dims; ++d)
986 {
987 log_det_val += std::log( (std::max)( dcovs_colmem[d], std::numeric_limits<eT>::min() ) );
988 }
989
990 log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
991 }
992
993 //
994
995 eT* hefts_mem = access::rw(hefts).memptr();
996
997 for(uword g=0; g < N_gaus; ++g)
998 {
999 hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
1000 }
1001
1002 log_hefts = log(hefts);
1003 }
1004
1005
1006
1007 template<typename eT>
1008 inline
1009 umat
internal_gen_boundaries(const uword N) const1010 gmm_diag<eT>::internal_gen_boundaries(const uword N) const
1011 {
1012 arma_extra_debug_sigprint();
1013
1014 #if defined(ARMA_USE_OPENMP)
1015 const uword n_threads_avail = (omp_in_parallel()) ? uword(1) : uword(omp_get_max_threads());
1016 const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
1017 #else
1018 static constexpr uword n_threads = 1;
1019 #endif
1020
1021 // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
1022
1023 umat boundaries(2, n_threads, arma_nozeros_indicator());
1024
1025 if(N > 0)
1026 {
1027 const uword chunk_size = N / n_threads;
1028
1029 uword count = 0;
1030
1031 for(uword t=0; t<n_threads; t++)
1032 {
1033 boundaries.at(0,t) = count;
1034
1035 count += chunk_size;
1036
1037 boundaries.at(1,t) = count-1;
1038 }
1039
1040 boundaries.at(1,n_threads-1) = N - 1;
1041 }
1042 else
1043 {
1044 boundaries.zeros();
1045 }
1046
1047 // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
1048
1049 return boundaries;
1050 }
1051
1052
1053
1054 template<typename eT>
1055 arma_hot
1056 inline
1057 eT
internal_scalar_log_p(const eT * x) const1058 gmm_diag<eT>::internal_scalar_log_p(const eT* x) const
1059 {
1060 arma_extra_debug_sigprint();
1061
1062 const eT* log_hefts_mem = log_hefts.mem;
1063
1064 const uword N_gaus = means.n_cols;
1065
1066 if(N_gaus > 0)
1067 {
1068 eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
1069
1070 for(uword g=1; g < N_gaus; ++g)
1071 {
1072 const eT tmp = internal_scalar_log_p(x, g) + log_hefts_mem[g];
1073
1074 log_sum = log_add_exp(log_sum, tmp);
1075 }
1076
1077 return log_sum;
1078 }
1079 else
1080 {
1081 return -Datum<eT>::inf;
1082 }
1083 }
1084
1085
1086
1087 template<typename eT>
1088 arma_hot
1089 inline
1090 eT
internal_scalar_log_p(const eT * x,const uword g) const1091 gmm_diag<eT>::internal_scalar_log_p(const eT* x, const uword g) const
1092 {
1093 arma_extra_debug_sigprint();
1094
1095 const eT* mean = means.colptr(g);
1096 const eT* inv_dcov = inv_dcovs.colptr(g);
1097
1098 const uword N_dims = means.n_rows;
1099
1100 eT val_i = eT(0);
1101 eT val_j = eT(0);
1102
1103 uword i,j;
1104
1105 for(i=0, j=1; j<N_dims; i+=2, j+=2)
1106 {
1107 eT tmp_i = x[i];
1108 eT tmp_j = x[j];
1109
1110 tmp_i -= mean[i];
1111 tmp_j -= mean[j];
1112
1113 val_i += (tmp_i*tmp_i) * inv_dcov[i];
1114 val_j += (tmp_j*tmp_j) * inv_dcov[j];
1115 }
1116
1117 if(i < N_dims)
1118 {
1119 const eT tmp = x[i] - mean[i];
1120
1121 val_i += (tmp*tmp) * inv_dcov[i];
1122 }
1123
1124 return eT(-0.5)*(val_i + val_j) + log_det_etc.mem[g];
1125 }
1126
1127
1128
1129 template<typename eT>
1130 inline
1131 Row<eT>
internal_vec_log_p(const Mat<eT> & X) const1132 gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X) const
1133 {
1134 arma_extra_debug_sigprint();
1135
1136 arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
1137
1138 const uword N = X.n_cols;
1139
1140 Row<eT> out(N, arma_nozeros_indicator());
1141
1142 if(N > 0)
1143 {
1144 #if defined(ARMA_USE_OPENMP)
1145 {
1146 const umat boundaries = internal_gen_boundaries(N);
1147
1148 const uword n_threads = boundaries.n_cols;
1149
1150 #pragma omp parallel for schedule(static)
1151 for(uword t=0; t < n_threads; ++t)
1152 {
1153 const uword start_index = boundaries.at(0,t);
1154 const uword end_index = boundaries.at(1,t);
1155
1156 eT* out_mem = out.memptr();
1157
1158 for(uword i=start_index; i <= end_index; ++i)
1159 {
1160 out_mem[i] = internal_scalar_log_p( X.colptr(i) );
1161 }
1162 }
1163 }
1164 #else
1165 {
1166 eT* out_mem = out.memptr();
1167
1168 for(uword i=0; i < N; ++i)
1169 {
1170 out_mem[i] = internal_scalar_log_p( X.colptr(i) );
1171 }
1172 }
1173 #endif
1174 }
1175
1176 return out;
1177 }
1178
1179
1180
template<typename eT>
inline
Row<eT>
gmm_diag<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
  {
  arma_extra_debug_sigprint();
  
  // per-sample log-likelihood under a single specified gaussian (gaus_id),
  // rather than the overall model; one entry per column of X
  
  arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::log_p(): incompatible dimensions" );
  arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::log_p(): specified gaussian is out of range" );
  
  const uword N = X.n_cols;
  
  Row<eT> out(N, arma_nozeros_indicator());
  
  if(N > 0)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      // partition the samples into contiguous chunks, one chunk per thread
      const umat boundaries = internal_gen_boundaries(N);
      
      const uword n_threads = boundaries.n_cols;
      
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        const uword start_index = boundaries.at(0,t);
        const uword end_index = boundaries.at(1,t);
        
        eT* out_mem = out.memptr();
        
        // NOTE: end_index is inclusive
        for(uword i=start_index; i <= end_index; ++i)
          {
          out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
          }
        }
      }
    #else
      {
      eT* out_mem = out.memptr();
      
      for(uword i=0; i < N; ++i)
        {
        out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
        }
      }
    #endif
    }
  
  return out;
  }
1231
1232
1233
template<typename eT>
inline
eT
gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X) const
  {
  arma_extra_debug_sigprint();
  
  // sum of the per-sample model log-likelihoods over all columns of X;
  // an empty X is deliberately reported as -inf
  
  arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" );
  
  const uword N = X.n_cols;
  
  if(N == 0) { return (-Datum<eT>::inf); }
  
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(N);
    
    const uword n_threads = boundaries.n_cols;
    
    // one partial sum per thread; combined after the parallel region
    Col<eT> t_accs(n_threads, arma_zeros_indicator());
    
    #pragma omp parallel for schedule(static)
    for(uword t=0; t < n_threads; ++t)
      {
      const uword start_index = boundaries.at(0,t);
      const uword end_index = boundaries.at(1,t);
      
      eT t_acc = eT(0);
      
      // NOTE: end_index is inclusive
      for(uword i=start_index; i <= end_index; ++i)
        {
        t_acc += internal_scalar_log_p( X.colptr(i) );
        }
      
      t_accs[t] = t_acc;
      }
    
    return eT(accu(t_accs));
    }
  #else
    {
    eT acc = eT(0);
    
    for(uword i=0; i<N; ++i)
      {
      acc += internal_scalar_log_p( X.colptr(i) );
      }
    
    return acc;
    }
  #endif
  }
1287
1288
1289
template<typename eT>
inline
eT
gmm_diag<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
  {
  arma_extra_debug_sigprint();
  
  // sum of the per-sample log-likelihoods under a single specified gaussian;
  // an empty X is deliberately reported as -inf
  
  arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::sum_log_p(): incompatible dimensions" );
  arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::sum_log_p(): specified gaussian is out of range" );
  
  const uword N = X.n_cols;
  
  if(N == 0) { return (-Datum<eT>::inf); }
  
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(N);
    
    const uword n_threads = boundaries.n_cols;
    
    // one partial sum per thread; combined after the parallel region
    Col<eT> t_accs(n_threads, arma_zeros_indicator());
    
    #pragma omp parallel for schedule(static)
    for(uword t=0; t < n_threads; ++t)
      {
      const uword start_index = boundaries.at(0,t);
      const uword end_index = boundaries.at(1,t);
      
      eT t_acc = eT(0);
      
      // NOTE: end_index is inclusive
      for(uword i=start_index; i <= end_index; ++i)
        {
        t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
        }
      
      t_accs[t] = t_acc;
      }
    
    return eT(accu(t_accs));
    }
  #else
    {
    eT acc = eT(0);
    
    for(uword i=0; i<N; ++i)
      {
      acc += internal_scalar_log_p( X.colptr(i), gaus_id );
      }
    
    return acc;
    }
  #endif
  }
1344
1345
1346
template<typename eT>
inline
eT
gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X) const
  {
  arma_extra_debug_sigprint();
  
  // average of the per-sample model log-likelihoods over all columns of X;
  // an empty X is deliberately reported as -inf
  
  arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" );
  
  const uword N = X.n_cols;
  
  if(N == 0) { return (-Datum<eT>::inf); }
  
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(N);
    
    const uword n_threads = boundaries.n_cols;
    
    // one running mean per thread; merged below via count-weighted averaging
    field< running_mean_scalar<eT> > t_running_means(n_threads);
    
    
    #pragma omp parallel for schedule(static)
    for(uword t=0; t < n_threads; ++t)
      {
      const uword start_index = boundaries.at(0,t);
      const uword end_index = boundaries.at(1,t);
      
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      // NOTE: end_index is inclusive
      for(uword i=start_index; i <= end_index; ++i)
        {
        current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
        }
      }
    
    
    eT avg = eT(0);
    
    // each thread's mean contributes in proportion to the number
    // of samples it processed: w = count/N
    for(uword t=0; t < n_threads; ++t)
      {
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      const eT w = eT(current_running_mean.count()) / eT(N);
      
      avg += w * current_running_mean.mean();
      }
    
    return avg;
    }
  #else
    {
    running_mean_scalar<eT> running_mean;
    
    for(uword i=0; i<N; ++i)
      {
      running_mean( internal_scalar_log_p( X.colptr(i) ) );
      }
    
    return running_mean.mean();
    }
  #endif
  }
1411
1412
1413
template<typename eT>
inline
eT
gmm_diag<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
  {
  arma_extra_debug_sigprint();
  
  // average of the per-sample log-likelihoods under a single specified gaussian;
  // an empty X is deliberately reported as -inf
  
  arma_debug_check( (X.n_rows != means.n_rows), "gmm_diag::avg_log_p(): incompatible dimensions" );
  arma_debug_check( (gaus_id >= means.n_cols), "gmm_diag::avg_log_p(): specified gaussian is out of range" );
  
  const uword N = X.n_cols;
  
  if(N == 0) { return (-Datum<eT>::inf); }
  
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(N);
    
    const uword n_threads = boundaries.n_cols;
    
    // one running mean per thread; merged below via count-weighted averaging
    field< running_mean_scalar<eT> > t_running_means(n_threads);
    
    
    #pragma omp parallel for schedule(static)
    for(uword t=0; t < n_threads; ++t)
      {
      const uword start_index = boundaries.at(0,t);
      const uword end_index = boundaries.at(1,t);
      
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      // NOTE: end_index is inclusive
      for(uword i=start_index; i <= end_index; ++i)
        {
        current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
        }
      }
    
    
    eT avg = eT(0);
    
    // each thread's mean contributes in proportion to the number
    // of samples it processed: w = count/N
    for(uword t=0; t < n_threads; ++t)
      {
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      const eT w = eT(current_running_mean.count()) / eT(N);
      
      avg += w * current_running_mean.mean();
      }
    
    return avg;
    }
  #else
    {
    running_mean_scalar<eT> running_mean;
    
    for(uword i=0; i<N; ++i)
      {
      running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
      }
    
    return running_mean.mean();
    }
  #endif
  }
1479
1480
1481
1482 template<typename eT>
1483 inline
1484 uword
internal_scalar_assign(const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1485 gmm_diag<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1486 {
1487 arma_extra_debug_sigprint();
1488
1489 const uword N_dims = means.n_rows;
1490 const uword N_gaus = means.n_cols;
1491
1492 arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
1493 arma_debug_check( (N_gaus == 0), "gmm_diag::assign(): model has no means" );
1494
1495 const eT* X_mem = X.colptr(0);
1496
1497 if(dist_mode == eucl_dist)
1498 {
1499 eT best_dist = Datum<eT>::inf;
1500 uword best_g = 0;
1501
1502 for(uword g=0; g < N_gaus; ++g)
1503 {
1504 const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
1505
1506 if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
1507 }
1508
1509 return best_g;
1510 }
1511 else
1512 if(dist_mode == prob_dist)
1513 {
1514 const eT* log_hefts_mem = log_hefts.memptr();
1515
1516 eT best_p = -Datum<eT>::inf;
1517 uword best_g = 0;
1518
1519 for(uword g=0; g < N_gaus; ++g)
1520 {
1521 const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
1522
1523 if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
1524 }
1525
1526 return best_g;
1527 }
1528 else
1529 {
1530 arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
1531 }
1532
1533 return uword(0);
1534 }
1535
1536
1537
template<typename eT>
inline
void
gmm_diag<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  {
  arma_extra_debug_sigprint();
  
  // assign each column of X to the closest gaussian, writing the winning
  // gaussian index per sample into 'out'; "closest" is either smallest
  // Euclidean distance (eucl_dist) or largest heft-weighted log-likelihood (prob_dist)
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  arma_debug_check( (X.n_rows != N_dims), "gmm_diag::assign(): incompatible dimensions" );
  
  // a model with no gaussians produces an empty assignment vector
  const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
  
  out.set_size(1,X_n_cols);
  
  uword* out_mem = out.memptr();
  
  if(dist_mode == eucl_dist)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      // samples are independent, so a plain parallel for over columns suffices
      #pragma omp parallel for schedule(static)
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT best_dist = Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          // aux argument presumably unused for dist_id=1 — see distance<eT,1>::eval
          const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
          
          // "<=" resolves ties in favour of the highest gaussian index
          if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
          }
        
        out_mem[i] = best_g;
        }
      }
    #else
      {
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT best_dist = Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
          
          if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
          }
        
        out_mem[i] = best_g;
        }
      }
    #endif
    }
  else
  if(dist_mode == prob_dist)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      const eT* log_hefts_mem = log_hefts.memptr();
      
      #pragma omp parallel for schedule(static)
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT best_p = -Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
          
          // ">=" resolves ties in favour of the highest gaussian index
          if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
          }
        
        out_mem[i] = best_g;
        }
      }
    #else
      {
      const eT* log_hefts_mem = log_hefts.memptr();
      
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT best_p = -Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
          
          if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
          }
        
        out_mem[i] = best_g;
        }
      }
    #endif
    }
  else
    {
    arma_debug_check(true, "gmm_diag::assign(): unsupported distance mode");
    }
  }
1652
1653
1654
1655
template<typename eT>
inline
void
gmm_diag<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  {
  arma_extra_debug_sigprint();
  
  // build a histogram of gaussian membership: hist[g] = number of columns of X
  // whose closest gaussian is g, under the chosen distance mode
  //
  // NOTE(review): unlike internal_vec_assign(), X.n_rows is not validated against
  // means.n_rows here — presumably the callers check dimensions; confirm upstream
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  const uword X_n_cols = X.n_cols;
  
  hist.zeros(N_gaus);
  
  if(N_gaus == 0) { return; }
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(X_n_cols);
    
    const uword n_threads = boundaries.n_cols;
    
    // one private histogram per thread; summed after the parallel regions
    field<urowvec> thread_hist(n_threads);
    
    for(uword t=0; t < n_threads; ++t) { thread_hist(t).zeros(N_gaus); }
    
    
    if(dist_mode == eucl_dist)
      {
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        uword* thread_hist_mem = thread_hist(t).memptr();
        
        const uword start_index = boundaries.at(0,t);
        const uword end_index = boundaries.at(1,t);
        
        // NOTE: end_index is inclusive
        for(uword i=start_index; i <= end_index; ++i)
          {
          const eT* X_colptr = X.colptr(i);
          
          eT best_dist = Datum<eT>::inf;
          uword best_g = 0;
          
          for(uword g=0; g < N_gaus; ++g)
            {
            // aux argument presumably unused for dist_id=1 — see distance<eT,1>::eval
            const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
            
            if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
            }
          
          thread_hist_mem[best_g]++;
          }
        }
      }
    else
    if(dist_mode == prob_dist)
      {
      const eT* log_hefts_mem = log_hefts.memptr();
      
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        uword* thread_hist_mem = thread_hist(t).memptr();
        
        const uword start_index = boundaries.at(0,t);
        const uword end_index = boundaries.at(1,t);
        
        for(uword i=start_index; i <= end_index; ++i)
          {
          const eT* X_colptr = X.colptr(i);
          
          eT best_p = -Datum<eT>::inf;
          uword best_g = 0;
          
          for(uword g=0; g < N_gaus; ++g)
            {
            const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
            
            if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
            }
          
          thread_hist_mem[best_g]++;
          }
        }
      }
    
    // reduction
    hist = thread_hist(0);
    
    for(uword t=1; t < n_threads; ++t)
      {
      hist += thread_hist(t);
      }
    }
  #else
    {
    uword* hist_mem = hist.memptr();
    
    if(dist_mode == eucl_dist)
      {
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT best_dist = Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g < N_gaus; ++g)
          {
          const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
          
          if(tmp_dist <= best_dist) { best_dist = tmp_dist; best_g = g; }
          }
        
        hist_mem[best_g]++;
        }
      }
    else
    if(dist_mode == prob_dist)
      {
      const eT* log_hefts_mem = log_hefts.memptr();
      
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT best_p = -Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g < N_gaus; ++g)
          {
          const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
          
          if(tmp_p >= best_p) { best_p = tmp_p; best_g = g; }
          }
        
        hist_mem[best_g]++;
        }
      }
    }
  #endif
  }
1799
1800
1801
template<typename eT>
template<uword dist_id>
inline
void
gmm_diag<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
  {
  arma_extra_debug_sigprint();
  
  // seed the means from the training data X:
  //  - static_subset / random_subset: pick N_gaus existing samples directly
  //  - static_spread / random_spread: greedily pick samples that are far apart
  //    (furthest-point style seeding), measured via distance<eT,dist_id>
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  if( (seed_mode == static_subset) || (seed_mode == random_subset) )
    {
    uvec initial_indices;
    
    // static_subset: evenly spaced deterministic indices;
    // random_subset: random permutation-based indices
    if(seed_mode == static_subset) { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
    else if(seed_mode == random_subset) { initial_indices = randperm<uvec>(X.n_cols, N_gaus); }
    
    // initial_indices.print("initial_indices:");
    
    access::rw(means) = X.cols(initial_indices);
    }
  else
  if( (seed_mode == static_spread) || (seed_mode == random_spread) )
    {
    // going through all of the samples can be extremely time consuming;
    // instead, if there are enough samples, randomly choose samples with probability 0.1
    
    const bool use_sampling = ((X.n_cols/uword(100)) > N_gaus);
    const uword step = (use_sampling) ? uword(10) : uword(1);
    
    uword start_index = 0;
    
    // the first mean is a single seed sample; the rest are chosen relative to it
    if(seed_mode == static_spread) { start_index = X.n_cols / 2; }
    else if(seed_mode == random_spread) { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
    
    access::rw(means).col(0) = X.unsafe_col(start_index);
    
    const eT* mah_aux_mem = mah_aux.memptr();
    
    // NOTE: the running statistic deliberately uses double precision,
    // regardless of eT
    running_stat<double> rs;
    
    for(uword g=1; g < N_gaus; ++g)
      {
      eT max_dist = eT(0);
      uword best_i = uword(0);
      uword start_i = uword(0);
      
      if(use_sampling)
        {
        // vary the starting offset of the strided scan, so different
        // candidate subsets are examined for each mean
        uword start_i_proposed = uword(0);
        
        if(seed_mode == static_spread) { start_i_proposed = g % uword(10); }
        if(seed_mode == random_spread) { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
        
        if(start_i_proposed < X.n_cols) { start_i = start_i_proposed; }
        }
      
      
      for(uword i=start_i; i < X.n_cols; i += step)
        {
        rs.reset();
        
        const eT* X_colptr = X.colptr(i);
        
        bool ignore_i = false;
        
        // find the average distance between sample i and the means so far
        for(uword h = 0; h < g; ++h)
          {
          const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
          
          // ignore sample already selected as a mean
          if(dist == eT(0)) { ignore_i = true; break; }
          else { rs(dist); }
          }
        
        if( (rs.mean() >= max_dist) && (ignore_i == false))
          {
          max_dist = eT(rs.mean()); best_i = i;
          }
        }
      
      // set the mean to the sample that is the furthest away from the means so far
      access::rw(means).col(g) = X.unsafe_col(best_i);
      }
    }
  
  // get_cout_stream() << "generate_initial_means():" << '\n';
  // means.print();
  }
1893
1894
1895
template<typename eT>
template<uword dist_id>
inline
void
gmm_diag<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
  {
  arma_extra_debug_sigprint();
  
  // given already-seeded means, derive initial diagonal covariances and hefts:
  // assign each sample to its nearest mean, then compute per-cluster
  // mean / variance / weight from the resulting hard assignment
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  const eT* mah_aux_mem = mah_aux.memptr();
  
  const uword X_n_cols = X.n_cols;
  
  if(X_n_cols == 0) { return; }
  
  // as the covariances are calculated via accumulators,
  // the means also need to be calculated via accumulators to ensure numerical consistency
  
  Mat<eT> acc_means(N_dims, N_gaus, arma_zeros_indicator());
  Mat<eT> acc_dcovs(N_dims, N_gaus, arma_zeros_indicator());
  
  // acc_hefts[g] = number of samples assigned to gaussian g
  Row<uword> acc_hefts(N_gaus, arma_zeros_indicator());
  
  uword* acc_hefts_mem = acc_hefts.memptr();
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(X_n_cols);
    
    const uword n_threads = boundaries.n_cols;
    
    // per-thread accumulators ("map"); summed afterwards ("reduce")
    field< Mat<eT> > t_acc_means(n_threads);
    field< Mat<eT> > t_acc_dcovs(n_threads);
    field< Row<uword> > t_acc_hefts(n_threads);
    
    for(uword t=0; t < n_threads; ++t)
      {
      t_acc_means(t).zeros(N_dims, N_gaus);
      t_acc_dcovs(t).zeros(N_dims, N_gaus);
      t_acc_hefts(t).zeros(N_gaus);
      }
    
    #pragma omp parallel for schedule(static)
    for(uword t=0; t < n_threads; ++t)
      {
      uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
      
      const uword start_index = boundaries.at(0,t);
      const uword end_index = boundaries.at(1,t);
      
      // NOTE: end_index is inclusive
      for(uword i=start_index; i <= end_index; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        // find the nearest mean for this sample
        eT min_dist = Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
          
          if(dist < min_dist) { min_dist = dist; best_g = g; }
          }
        
        eT* t_acc_mean = t_acc_means(t).colptr(best_g);
        eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
        
        // accumulate sum(x) and sum(x^2) for mean and variance
        for(uword d=0; d<N_dims; ++d)
          {
          const eT x_d = X_colptr[d];
          
          t_acc_mean[d] += x_d;
          t_acc_dcov[d] += x_d*x_d;
          }
        
        t_acc_hefts_mem[best_g]++;
        }
      }
    
    // reduction
    acc_means = t_acc_means(0);
    acc_dcovs = t_acc_dcovs(0);
    acc_hefts = t_acc_hefts(0);
    
    for(uword t=1; t < n_threads; ++t)
      {
      acc_means += t_acc_means(t);
      acc_dcovs += t_acc_dcovs(t);
      acc_hefts += t_acc_hefts(t);
      }
    }
  #else
    {
    for(uword i=0; i<X_n_cols; ++i)
      {
      const eT* X_colptr = X.colptr(i);
      
      eT min_dist = Datum<eT>::inf;
      uword best_g = 0;
      
      for(uword g=0; g<N_gaus; ++g)
        {
        const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
        
        if(dist < min_dist) { min_dist = dist; best_g = g; }
        }
      
      eT* acc_mean = acc_means.colptr(best_g);
      eT* acc_dcov = acc_dcovs.colptr(best_g);
      
      for(uword d=0; d<N_dims; ++d)
        {
        const eT x_d = X_colptr[d];
        
        acc_mean[d] += x_d;
        acc_dcov[d] += x_d*x_d;
        }
      
      acc_hefts_mem[best_g]++;
      }
    }
  #endif
  
  // finalise the parameters from the accumulators
  eT* hefts_mem = access::rw(hefts).memptr();
  
  for(uword g=0; g<N_gaus; ++g)
    {
    const eT* acc_mean = acc_means.colptr(g);
    const eT* acc_dcov = acc_dcovs.colptr(g);
    const uword acc_heft = acc_hefts_mem[g];
    
    eT* mean = access::rw(means).colptr(g);
    eT* dcov = access::rw(dcovs).colptr(g);
    
    for(uword d=0; d<N_dims; ++d)
      {
      // NOTE: when acc_heft is 0, this division yields NaN/Inf (eT is float or double),
      // but tmp is then discarded by the ternaries below
      const eT tmp = acc_mean[d] / eT(acc_heft);
      
      // a variance estimate needs at least 2 members; otherwise fall back to var_floor
      mean[d] = (acc_heft >= 1) ? tmp : eT(0);
      dcov[d] = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
      }
    
    hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
    }
  
  // sanitise the parameters (e.g. enforce the variance floor)
  em_fix_params(var_floor);
  }
2045
2046
2047
//! multi-threaded implementation of k-means, inspired by MapReduce
template<typename eT>
template<uword dist_id>
inline
bool
gmm_diag<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose, const char* signature)
  {
  arma_extra_debug_sigprint();
  
  // iteratively refine the means via k-means:
  //  1. "map": assign each sample to its nearest mean, accumulating per-cluster sums
  //  2. "reduce": combine per-thread accumulators, compute new means
  //  3. resurrect any means that attracted zero samples
  //  4. stop when the average mean shift drops to <= machine epsilon
  // returns false on unrecoverable failure (e.g. non-finite means)
  
  if(verbose)
    {
    // normalise stream formatting for progress reporting
    get_cout_stream().unsetf(ios::showbase);
    get_cout_stream().unsetf(ios::uppercase);
    get_cout_stream().unsetf(ios::showpos);
    get_cout_stream().unsetf(ios::scientific);
    
    get_cout_stream().setf(ios::right);
    get_cout_stream().setf(ios::fixed);
    }
  
  const uword X_n_cols = X.n_cols;
  
  if(X_n_cols == 0) { return true; }
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  const eT* mah_aux_mem = mah_aux.memptr();
  
  Mat<eT> acc_means(N_dims, N_gaus, arma_zeros_indicator());
  Row<uword> acc_hefts( N_gaus, arma_zeros_indicator());
  
  // last_indx[g] = index of the last sample assigned to mean g;
  // used to resurrect dead means below
  Row<uword> last_indx( N_gaus, arma_zeros_indicator());
  
  Mat<eT> new_means = means;
  Mat<eT> old_means = means;
  
  // average per-mean movement between successive iterations
  running_mean_scalar<eT> rs_delta;
  
  #if defined(ARMA_USE_OPENMP)
    const umat boundaries = internal_gen_boundaries(X_n_cols);
    const uword n_threads = boundaries.n_cols;
    
    field< Mat<eT> > t_acc_means(n_threads);
    field< Row<uword> > t_acc_hefts(n_threads);
    field< Row<uword> > t_last_indx(n_threads);
  #else
    const uword n_threads = 1;
  #endif
  
  if(verbose) { get_cout_stream() << signature << ": n_threads: " << n_threads << '\n'; get_cout_stream().flush(); }
  
  for(uword iter=1; iter <= max_iter; ++iter)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      // reset per-thread accumulators for this iteration
      for(uword t=0; t < n_threads; ++t)
        {
        t_acc_means(t).zeros(N_dims, N_gaus);
        t_acc_hefts(t).zeros(N_gaus);
        t_last_indx(t).zeros(N_gaus);
        }
      
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        Mat<eT>& t_acc_means_t = t_acc_means(t);
        uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
        uword* t_last_indx_mem = t_last_indx(t).memptr();
        
        const uword start_index = boundaries.at(0,t);
        const uword end_index = boundaries.at(1,t);
        
        // NOTE: end_index is inclusive
        for(uword i=start_index; i <= end_index; ++i)
          {
          const eT* X_colptr = X.colptr(i);
          
          eT min_dist = Datum<eT>::inf;
          uword best_g = 0;
          
          for(uword g=0; g<N_gaus; ++g)
            {
            const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
            
            if(dist < min_dist) { min_dist = dist; best_g = g; }
            }
          
          eT* t_acc_mean = t_acc_means_t.colptr(best_g);
          
          for(uword d=0; d<N_dims; ++d) { t_acc_mean[d] += X_colptr[d]; }
          
          t_acc_hefts_mem[best_g]++;
          t_last_indx_mem[best_g] = i;
          }
        }
      
      // reduction
      
      acc_means = t_acc_means(0);
      acc_hefts = t_acc_hefts(0);
      
      for(uword t=1; t < n_threads; ++t)
        {
        acc_means += t_acc_means(t);
        acc_hefts += t_acc_hefts(t);
        }
      
      // keep one valid representative sample index per non-empty mean
      for(uword g=0; g < N_gaus; ++g)
      for(uword t=0; t < n_threads; ++t)
        {
        if( t_acc_hefts(t)(g) >= 1 ) { last_indx(g) = t_last_indx(t)(g); }
        }
      }
    #else
      {
      uword* acc_hefts_mem = acc_hefts.memptr();
      uword* last_indx_mem = last_indx.memptr();
      
      for(uword i=0; i < X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT min_dist = Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
          
          if(dist < min_dist) { min_dist = dist; best_g = g; }
          }
        
        eT* acc_mean = acc_means.colptr(best_g);
        
        for(uword d=0; d<N_dims; ++d) { acc_mean[d] += X_colptr[d]; }
        
        acc_hefts_mem[best_g]++;
        last_indx_mem[best_g] = i;
        }
      }
    #endif
    
    // generate new means
    
    uword* acc_hefts_mem = acc_hefts.memptr();
    
    for(uword g=0; g < N_gaus; ++g)
      {
      const eT* acc_mean = acc_means.colptr(g);
      const uword acc_heft = acc_hefts_mem[g];
      
      eT* new_mean = access::rw(new_means).colptr(g);
      
      for(uword d=0; d<N_dims; ++d)
        {
        new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
        }
      }
    
    
    // heuristics to resurrect dead means
    
    const uvec dead_gs = find(acc_hefts == uword(0));
    
    if(dead_gs.n_elem > 0)
      {
      if(verbose) { get_cout_stream() << signature << ": recovering from dead means\n"; get_cout_stream().flush(); }
      
      uword* last_indx_mem = last_indx.memptr();
      
      // donor means: those with at least 2 members, largest index first
      const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
      
      if(live_gs.n_elem == 0) { return false; }
      
      uword live_gs_count = 0;
      
      for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
        {
        const uword dead_g_id = dead_gs(dead_gs_count);
        
        uword proposed_i = 0;
        
        if(live_gs_count < live_gs.n_elem)
          {
          const uword live_g_id = live_gs(live_gs_count); ++live_gs_count;
          
          // a mean cannot be both dead and live; bail out on inconsistency
          if(live_g_id == dead_g_id) { return false; }
          
          // recover by using a sample from a known good mean
          proposed_i = last_indx_mem[live_g_id];
          }
        else
          {
          // recover by using a randomly selected sample (last resort)
          proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
          }
        
        if(proposed_i >= X_n_cols) { return false; }
        
        new_means.col(dead_g_id) = X.col(proposed_i);
        }
      }
    
    // measure how far the means moved in this iteration
    rs_delta.reset();
    
    for(uword g=0; g < N_gaus; ++g)
      {
      rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
      }
    
    if(verbose)
      {
      get_cout_stream() << signature << ": iteration: ";
      get_cout_stream().unsetf(ios::scientific);
      get_cout_stream().setf(ios::fixed);
      get_cout_stream().width(std::streamsize(4));
      get_cout_stream() << iter;
      get_cout_stream() << " delta: ";
      get_cout_stream().unsetf(ios::fixed);
      //get_cout_stream().setf(ios::scientific);
      get_cout_stream() << rs_delta.mean() << '\n';
      get_cout_stream().flush();
      }
    
    arma::swap(old_means, new_means);
    
    // converged: average mean movement is negligible
    if(rs_delta.mean() <= Datum<eT>::eps) { break; }
    }
  
  // old_means holds the latest means due to the swap above
  access::rw(means) = old_means;
  
  if(means.is_finite() == false) { return false; }
  
  return true;
  }
2282
2283
2284
//! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce
template<typename eT>
inline
bool
gmm_diag<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
  {
  arma_extra_debug_sigprint();
  
  // refine means, diagonal covariances and hefts via EM until the average
  // log-likelihood stops improving (or max_iter is reached);
  // returns false if the model becomes degenerate or non-finite
  
  if(X.n_cols == 0) { return true; }
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  if(verbose)
    {
    // normalise stream formatting for progress reporting
    get_cout_stream().unsetf(ios::showbase);
    get_cout_stream().unsetf(ios::uppercase);
    get_cout_stream().unsetf(ios::showpos);
    get_cout_stream().unsetf(ios::scientific);
    
    get_cout_stream().setf(ios::right);
    get_cout_stream().setf(ios::fixed);
    }
  
  // per-thread workspaces, reused across iterations
  const umat boundaries = internal_gen_boundaries(X.n_cols);
  
  const uword n_threads = boundaries.n_cols;
  
  field< Mat<eT> > t_acc_means(n_threads);
  field< Mat<eT> > t_acc_dcovs(n_threads);
  
  field< Col<eT> > t_acc_norm_lhoods(n_threads);
  field< Col<eT> > t_gaus_log_lhoods(n_threads);
  
  Col<eT> t_progress_log_lhood(n_threads, arma_nozeros_indicator());
  
  for(uword t=0; t<n_threads; t++)
    {
    t_acc_means[t].set_size(N_dims, N_gaus);
    t_acc_dcovs[t].set_size(N_dims, N_gaus);
    
    t_acc_norm_lhoods[t].set_size(N_gaus);
    t_gaus_log_lhoods[t].set_size(N_gaus);
    }
  
  
  if(verbose)
    {
    get_cout_stream() << "gmm_diag::learn(): EM: n_threads: " << n_threads << '\n';
    }
  
  eT old_avg_log_p = -Datum<eT>::inf;
  
  for(uword iter=1; iter <= max_iter; ++iter)
    {
    // refresh cached constants (e.g. log_hefts) derived from the current parameters
    init_constants();
    
    // E+M step: update means/dcovs/hefts from the responsibilities
    em_update_params(X, boundaries, t_acc_means, t_acc_dcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood);
    
    // sanitise the updated parameters (e.g. enforce the variance floor)
    em_fix_params(var_floor);
    
    const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
    
    if(verbose)
      {
      get_cout_stream() << "gmm_diag::learn(): EM: iteration: ";
      get_cout_stream().unsetf(ios::scientific);
      get_cout_stream().setf(ios::fixed);
      get_cout_stream().width(std::streamsize(4));
      get_cout_stream() << iter;
      get_cout_stream() << " avg_log_p: ";
      get_cout_stream().unsetf(ios::fixed);
      //get_cout_stream().setf(ios::scientific);
      get_cout_stream() << new_avg_log_p << '\n';
      get_cout_stream().flush();
      }
    
    if(arma_isfinite(new_avg_log_p) == false) { return false; }
    
    // converged: log-likelihood change is negligible
    if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps) { break; }
    
    
    old_avg_log_p = new_avg_log_p;
    }
  
  
  // final sanity checks: covariances must be strictly positive,
  // and all parameters must be finite
  if(any(vectorise(dcovs) <= eT(0))) { return false; }
  if(means.is_finite() == false ) { return false; }
  if(dcovs.is_finite() == false ) { return false; }
  if(hefts.is_finite() == false ) { return false; }
  
  return true;
  }
2378
2379
2380
2381
2382 template<typename eT>
2383 inline
2384 void
em_update_params(const Mat<eT> & X,const umat & boundaries,field<Mat<eT>> & t_acc_means,field<Mat<eT>> & t_acc_dcovs,field<Col<eT>> & t_acc_norm_lhoods,field<Col<eT>> & t_gaus_log_lhoods,Col<eT> & t_progress_log_lhood)2385 gmm_diag<eT>::em_update_params
2386 (
2387 const Mat<eT>& X,
2388 const umat& boundaries,
2389 field< Mat<eT> >& t_acc_means,
2390 field< Mat<eT> >& t_acc_dcovs,
2391 field< Col<eT> >& t_acc_norm_lhoods,
2392 field< Col<eT> >& t_gaus_log_lhoods,
2393 Col<eT>& t_progress_log_lhood
2394 )
2395 {
2396 arma_extra_debug_sigprint();
2397
2398 const uword n_threads = boundaries.n_cols;
2399
2400
2401 // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
2402
2403 #if defined(ARMA_USE_OPENMP)
2404 {
2405 #pragma omp parallel for schedule(static)
2406 for(uword t=0; t<n_threads; t++)
2407 {
2408 Mat<eT>& acc_means = t_acc_means[t];
2409 Mat<eT>& acc_dcovs = t_acc_dcovs[t];
2410 Col<eT>& acc_norm_lhoods = t_acc_norm_lhoods[t];
2411 Col<eT>& gaus_log_lhoods = t_gaus_log_lhoods[t];
2412 eT& progress_log_lhood = t_progress_log_lhood[t];
2413
2414 em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_dcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
2415 }
2416 }
2417 #else
2418 {
2419 em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_dcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
2420 }
2421 #endif
2422
2423 const uword N_dims = means.n_rows;
2424 const uword N_gaus = means.n_cols;
2425
2426 Mat<eT>& final_acc_means = t_acc_means[0];
2427 Mat<eT>& final_acc_dcovs = t_acc_dcovs[0];
2428
2429 Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
2430
2431
2432 // the "reduce" operation, which combines the partial accumulators produced by the separate threads
2433
2434 for(uword t=1; t<n_threads; t++)
2435 {
2436 final_acc_means += t_acc_means[t];
2437 final_acc_dcovs += t_acc_dcovs[t];
2438
2439 final_acc_norm_lhoods += t_acc_norm_lhoods[t];
2440 }
2441
2442
2443 eT* hefts_mem = access::rw(hefts).memptr();
2444
2445
2446 //// update each component without sanity checking
2447 //for(uword g=0; g < N_gaus; ++g)
2448 // {
2449 // const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
2450 //
2451 // eT* mean_mem = access::rw(means).colptr(g);
2452 // eT* dcov_mem = access::rw(dcovs).colptr(g);
2453 //
2454 // eT* acc_mean_mem = final_acc_means.colptr(g);
2455 // eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
2456 //
2457 // hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
2458 //
2459 // for(uword d=0; d < N_dims; ++d)
2460 // {
2461 // const eT tmp = acc_mean_mem[d] / acc_norm_lhood;
2462 //
2463 // mean_mem[d] = tmp;
2464 // dcov_mem[d] = acc_dcov_mem[d] / acc_norm_lhood - tmp*tmp;
2465 // }
2466 // }
2467
2468
2469 // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them
2470 for(uword g=0; g < N_gaus; ++g)
2471 {
2472 const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
2473
2474 if(arma_isfinite(acc_norm_lhood) == false) { continue; }
2475
2476 eT* acc_mean_mem = final_acc_means.colptr(g);
2477 eT* acc_dcov_mem = final_acc_dcovs.colptr(g);
2478
2479 bool ok = true;
2480
2481 for(uword d=0; d < N_dims; ++d)
2482 {
2483 const eT tmp1 = acc_mean_mem[d] / acc_norm_lhood;
2484 const eT tmp2 = acc_dcov_mem[d] / acc_norm_lhood - tmp1*tmp1;
2485
2486 acc_mean_mem[d] = tmp1;
2487 acc_dcov_mem[d] = tmp2;
2488
2489 if(arma_isfinite(tmp2) == false) { ok = false; }
2490 }
2491
2492
2493 if(ok)
2494 {
2495 hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
2496
2497 eT* mean_mem = access::rw(means).colptr(g);
2498 eT* dcov_mem = access::rw(dcovs).colptr(g);
2499
2500 for(uword d=0; d < N_dims; ++d)
2501 {
2502 mean_mem[d] = acc_mean_mem[d];
2503 dcov_mem[d] = acc_dcov_mem[d];
2504 }
2505 }
2506 }
2507 }
2508
2509
2510
2511 template<typename eT>
2512 inline
2513 void
em_generate_acc(const Mat<eT> & X,const uword start_index,const uword end_index,Mat<eT> & acc_means,Mat<eT> & acc_dcovs,Col<eT> & acc_norm_lhoods,Col<eT> & gaus_log_lhoods,eT & progress_log_lhood) const2514 gmm_diag<eT>::em_generate_acc
2515 (
2516 const Mat<eT>& X,
2517 const uword start_index,
2518 const uword end_index,
2519 Mat<eT>& acc_means,
2520 Mat<eT>& acc_dcovs,
2521 Col<eT>& acc_norm_lhoods,
2522 Col<eT>& gaus_log_lhoods,
2523 eT& progress_log_lhood
2524 )
2525 const
2526 {
2527 arma_extra_debug_sigprint();
2528
2529 progress_log_lhood = eT(0);
2530
2531 acc_means.zeros();
2532 acc_dcovs.zeros();
2533
2534 acc_norm_lhoods.zeros();
2535 gaus_log_lhoods.zeros();
2536
2537 const uword N_dims = means.n_rows;
2538 const uword N_gaus = means.n_cols;
2539
2540 const eT* log_hefts_mem = log_hefts.memptr();
2541 eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
2542
2543
2544 for(uword i=start_index; i <= end_index; i++)
2545 {
2546 const eT* x = X.colptr(i);
2547
2548 for(uword g=0; g < N_gaus; ++g)
2549 {
2550 gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
2551 }
2552
2553 eT log_lhood_sum = gaus_log_lhoods_mem[0];
2554
2555 for(uword g=1; g < N_gaus; ++g)
2556 {
2557 log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
2558 }
2559
2560 progress_log_lhood += log_lhood_sum;
2561
2562 for(uword g=0; g < N_gaus; ++g)
2563 {
2564 const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
2565
2566 acc_norm_lhoods[g] += norm_lhood;
2567
2568 eT* acc_mean_mem = acc_means.colptr(g);
2569 eT* acc_dcov_mem = acc_dcovs.colptr(g);
2570
2571 for(uword d=0; d < N_dims; ++d)
2572 {
2573 const eT x_d = x[d];
2574 const eT y_d = x_d * norm_lhood;
2575
2576 acc_mean_mem[d] += y_d;
2577 acc_dcov_mem[d] += y_d * x_d; // equivalent to x_d * x_d * norm_lhood
2578 }
2579 }
2580 }
2581
2582 progress_log_lhood /= eT((end_index - start_index) + 1);
2583 }
2584
2585
2586
2587 template<typename eT>
2588 inline
2589 void
em_fix_params(const eT var_floor)2590 gmm_diag<eT>::em_fix_params(const eT var_floor)
2591 {
2592 arma_extra_debug_sigprint();
2593
2594 const uword N_dims = means.n_rows;
2595 const uword N_gaus = means.n_cols;
2596
2597 const eT var_ceiling = std::numeric_limits<eT>::max();
2598
2599 const uword dcovs_n_elem = dcovs.n_elem;
2600 eT* dcovs_mem = access::rw(dcovs).memptr();
2601
2602 for(uword i=0; i < dcovs_n_elem; ++i)
2603 {
2604 eT& var_val = dcovs_mem[i];
2605
2606 if(var_val < var_floor ) { var_val = var_floor; }
2607 else if(var_val > var_ceiling) { var_val = var_ceiling; }
2608 else if(arma_isnan(var_val) ) { var_val = eT(1); }
2609 }
2610
2611
2612 eT* hefts_mem = access::rw(hefts).memptr();
2613
2614 for(uword g1=0; g1 < N_gaus; ++g1)
2615 {
2616 if(hefts_mem[g1] > eT(0))
2617 {
2618 const eT* means_colptr_g1 = means.colptr(g1);
2619
2620 for(uword g2=(g1+1); g2 < N_gaus; ++g2)
2621 {
2622 if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
2623 {
2624 const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
2625
2626 if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
2627 }
2628 }
2629 }
2630 }
2631
2632 const eT heft_floor = std::numeric_limits<eT>::min();
2633 const eT heft_initial = eT(1) / eT(N_gaus);
2634
2635 for(uword i=0; i < N_gaus; ++i)
2636 {
2637 eT& heft_val = hefts_mem[i];
2638
2639 if(heft_val < heft_floor) { heft_val = heft_floor; }
2640 else if(heft_val > eT(1) ) { heft_val = eT(1); }
2641 else if(arma_isnan(heft_val) ) { heft_val = heft_initial; }
2642 }
2643
2644 const eT heft_sum = accu(hefts);
2645
2646 if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps))) { access::rw(hefts) /= heft_sum; }
2647 }
2648
2649
2650 } // namespace gmm_priv
2651
2652
2653 //! @}
2654