1 // SPDX-License-Identifier: Apache-2.0
2 //
3 // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au)
4 // Copyright 2008-2016 National ICT Australia (NICTA)
5 //
6 // Licensed under the Apache License, Version 2.0 (the "License");
7 // you may not use this file except in compliance with the License.
8 // You may obtain a copy of the License at
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // ------------------------------------------------------------------------
17
18
19 //! \addtogroup gmm_full
20 //! @{
21
22
23 namespace gmm_priv
24 {
25
26
27 template<typename eT>
28 inline
~gmm_full()29 gmm_full<eT>::~gmm_full()
30 {
31 arma_extra_debug_sigprint_this(this);
32
33 arma_type_check(( (is_same_type<eT,float>::value == false) && (is_same_type<eT,double>::value == false) ));
34 }
35
36
37
//! default constructor; creates an empty model (no dimensions, no gaussians)
template<typename eT>
inline
gmm_full<eT>::gmm_full()
  {
  arma_extra_debug_sigprint_this(this);
  }
44
45
46
//! copy constructor; duplicates the means, covariances, weights and cached constants of the given model
template<typename eT>
inline
gmm_full<eT>::gmm_full(const gmm_full<eT>& x)
  {
  arma_extra_debug_sigprint_this(this);
  
  init(x);
  }
55
56
57
//! copy assignment; self-assignment is detected and handled inside init()
template<typename eT>
inline
gmm_full<eT>&
gmm_full<eT>::operator=(const gmm_full<eT>& x)
  {
  arma_extra_debug_sigprint();
  
  init(x);
  
  return *this;
  }
69
70
71
//! conversion constructor: builds a full-covariance model from a diagonal-covariance model
template<typename eT>
inline
gmm_full<eT>::gmm_full(const gmm_diag<eT>& x)
  {
  arma_extra_debug_sigprint_this(this);
  
  init(x);
  }
80
81
82
//! conversion assignment: replaces this model with a full-covariance copy of a diagonal-covariance model
template<typename eT>
inline
gmm_full<eT>&
gmm_full<eT>::operator=(const gmm_diag<eT>& x)
  {
  arma_extra_debug_sigprint();
  
  init(x);
  
  return *this;
  }
94
95
96
//! constructor with explicit size: creates a model with the given number of dimensions and gaussians,
//! using zero means, identity covariances and uniform weights (see init(uword,uword))
template<typename eT>
inline
gmm_full<eT>::gmm_full(const uword in_n_dims, const uword in_n_gaus)
  {
  arma_extra_debug_sigprint_this(this);
  
  init(in_n_dims, in_n_gaus);
  }
105
106
107
//! reset the model to the empty state (zero dimensions, zero gaussians)
template<typename eT>
inline
void
gmm_full<eT>::reset()
  {
  arma_extra_debug_sigprint();
  
  init(0, 0);
  }
117
118
119
//! resize the model to the given number of dimensions and gaussians;
//! existing parameters are discarded and replaced with defaults (see init(uword,uword))
template<typename eT>
inline
void
gmm_full<eT>::reset(const uword in_n_dims, const uword in_n_gaus)
  {
  arma_extra_debug_sigprint();
  
  init(in_n_dims, in_n_gaus);
  }
129
130
131
//! set all model parameters at once;
//! the given means, covariances and weights are validated for mutual consistency
//! before being committed, and the cached constants are refreshed
template<typename eT>
template<typename T1, typename T2, typename T3>
inline
void
gmm_full<eT>::set_params(const Base<eT,T1>& in_means_expr, const BaseCube<eT,T2>& in_fcovs_expr, const Base<eT,T3>& in_hefts_expr)
  {
  arma_extra_debug_sigprint();
  
  const unwrap <T1> tmp1(in_means_expr.get_ref());
  const unwrap_cube<T2> tmp2(in_fcovs_expr.get_ref());
  const unwrap <T3> tmp3(in_hefts_expr.get_ref());
  
  const Mat <eT>& in_means = tmp1.M;
  const Cube<eT>& in_fcovs = tmp2.M;
  const Mat <eT>& in_hefts = tmp3.M;
  
  // expected layout: means = one column per gaussian;
  // fcovs = one square slice per gaussian; hefts = row vector of weights
  arma_debug_check
    (
    (in_means.n_cols != in_fcovs.n_slices) || (in_means.n_rows != in_fcovs.n_rows) || (in_fcovs.n_rows != in_fcovs.n_cols) || (in_hefts.n_cols != in_means.n_cols) || (in_hefts.n_rows != 1),
    "gmm_full::set_params(): given parameters have inconsistent and/or wrong sizes"
    );
  
  arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_params(): given means have non-finite values" );
  arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_params(): given fcovs have non-finite values" );
  arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_params(): given hefts have non-finite values" );
  
  // each covariance matrix must have strictly positive diagonal entries
  for(uword g=0; g < in_fcovs.n_slices; ++g)
    {
    arma_debug_check( (any(diagvec(in_fcovs.slice(g)) <= eT(0))), "gmm_full::set_params(): given fcovs have negative or zero values on diagonals" );
    }
  
  arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_params(): given hefts have negative values" );
  
  const eT s = accu(in_hefts);
  
  // weights must sum to 1, within a small tolerance
  arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_params(): sum of given hefts is not 1" );
  
  access::rw(means) = in_means;
  access::rw(fcovs) = in_fcovs;
  access::rw(hefts) = in_hefts;
  
  // refresh cached quantities (inverses, log-determinants, cholesky factors, log weights)
  init_constants();
  }
175
176
177
178 template<typename eT>
179 template<typename T1>
180 inline
181 void
set_means(const Base<eT,T1> & in_means_expr)182 gmm_full<eT>::set_means(const Base<eT,T1>& in_means_expr)
183 {
184 arma_extra_debug_sigprint();
185
186 const unwrap<T1> tmp(in_means_expr.get_ref());
187
188 const Mat<eT>& in_means = tmp.M;
189
190 arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_full::set_means(): given means have incompatible size" );
191 arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_means(): given means have non-finite values" );
192
193 access::rw(means) = in_means;
194 }
195
196
197
198 template<typename eT>
199 template<typename T1>
200 inline
201 void
set_fcovs(const BaseCube<eT,T1> & in_fcovs_expr)202 gmm_full<eT>::set_fcovs(const BaseCube<eT,T1>& in_fcovs_expr)
203 {
204 arma_extra_debug_sigprint();
205
206 const unwrap_cube<T1> tmp(in_fcovs_expr.get_ref());
207
208 const Cube<eT>& in_fcovs = tmp.M;
209
210 arma_debug_check( (arma::size(in_fcovs) != arma::size(fcovs)), "gmm_full::set_fcovs(): given fcovs have incompatible size" );
211 arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_fcovs(): given fcovs have non-finite values" );
212
213 for(uword i=0; i < in_fcovs.n_slices; ++i)
214 {
215 arma_debug_check( (any(diagvec(in_fcovs.slice(i)) <= eT(0))), "gmm_full::set_fcovs(): given fcovs have negative or zero values on diagonals" );
216 }
217
218 access::rw(fcovs) = in_fcovs;
219
220 init_constants();
221 }
222
223
224
//! replace only the weights; the given weights are validated, clamped away from zero,
//! renormalised to sum to 1, and the cached log-weights are refreshed
template<typename eT>
template<typename T1>
inline
void
gmm_full<eT>::set_hefts(const Base<eT,T1>& in_hefts_expr)
  {
  arma_extra_debug_sigprint();
  
  const unwrap<T1> tmp(in_hefts_expr.get_ref());
  
  const Mat<eT>& in_hefts = tmp.M;
  
  arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_full::set_hefts(): given hefts have incompatible size" );
  arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_hefts(): given hefts have non-finite values" );
  arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_hefts(): given hefts have negative values" );
  
  const eT s = accu(in_hefts);
  
  arma_debug_check( ((s < (eT(1) - eT(0.001))) || (s > (eT(1) + eT(0.001)))), "gmm_full::set_hefts(): sum of given hefts is not 1" );
  
  // make sure all hefts are positive and non-zero
  
  const eT* in_hefts_mem = in_hefts.memptr();
  eT* hefts_mem = access::rw(hefts).memptr();
  
  // clamp each weight to the smallest positive normalised value, so log(hefts) stays finite
  for(uword i=0; i < hefts.n_elem; ++i)
    {
    hefts_mem[i] = (std::max)( in_hefts_mem[i], std::numeric_limits<eT>::min() );
    }
  
  // renormalise after clamping, then cache the log weights
  access::rw(hefts) /= accu(hefts);
  
  log_hefts = log(hefts);
  }
259
260
261
//! number of dimensions in the model (rows of the means matrix)
template<typename eT>
inline
uword
gmm_full<eT>::n_dims() const
  {
  return means.n_rows;
  }
269
270
271
//! number of gaussians in the model (columns of the means matrix)
template<typename eT>
inline
uword
gmm_full<eT>::n_gaus() const
  {
  return means.n_cols;
  }
279
280
281
//! load a model from the given file;
//! expected layout: a field of matrices holding [means, hefts, fcov_0, ..., fcov_{N_gaus-1}];
//! on any failure the model is reset and false is returned
template<typename eT>
inline
bool
gmm_full<eT>::load(const std::string name)
  {
  arma_extra_debug_sigprint();
  
  field< Mat<eT> > storage;
  
  bool status = storage.load(name, arma_binary);
  
  // need at least the means and hefts entries
  if( (status == false) || (storage.n_elem < 2) )
    {
    reset();
    arma_debug_warn_level(3, "gmm_full::load(): problem with loading or incompatible format");
    return false;
    }
  
  uword count = 0;
  
  const Mat<eT>& storage_means = storage(count); ++count;
  const Mat<eT>& storage_hefts = storage(count); ++count;
  
  const uword N_dims = storage_means.n_rows;
  const uword N_gaus = storage_means.n_cols;
  
  // one covariance matrix per gaussian must follow the means and hefts
  if( (storage.n_elem != (N_gaus + 2)) || (storage_hefts.n_rows != 1) || (storage_hefts.n_cols != N_gaus) )
    {
    reset();
    arma_debug_warn_level(3, "gmm_full::load(): incompatible format");
    return false;
    }
  
  reset(N_dims, N_gaus);
  
  access::rw(means) = storage_means;
  access::rw(hefts) = storage_hefts;
  
  for(uword g=0; g < N_gaus; ++g)
    {
    const Mat<eT>& storage_fcov = storage(count); ++count;
    
    // each covariance matrix must be square with N_dims rows
    if( (storage_fcov.n_rows != N_dims) || (storage_fcov.n_cols != N_dims) )
      {
      reset();
      arma_debug_warn_level(3, "gmm_full::load(): incompatible format");
      return false;
      }
    
    access::rw(fcovs).slice(g) = storage_fcov;
    }
  
  // recompute cached inverses, determinants, cholesky factors and log weights
  init_constants();
  
  return true;
  }
338
339
340
341 template<typename eT>
342 inline
343 bool
save(const std::string name) const344 gmm_full<eT>::save(const std::string name) const
345 {
346 arma_extra_debug_sigprint();
347
348 const uword N_gaus = means.n_cols;
349
350 field< Mat<eT> > storage(2 + N_gaus);
351
352 uword count = 0;
353
354 storage(count) = means; ++count;
355 storage(count) = hefts; ++count;
356
357 for(uword g=0; g < N_gaus; ++g)
358 {
359 storage(count) = fcovs.slice(g); ++count;
360 }
361
362 const bool status = storage.save(name, arma_binary);
363
364 return status;
365 }
366
367
368
//! draw one random vector from the model;
//! a gaussian is selected with probability proportional to its weight,
//! and a standard-normal sample is transformed by the cholesky factor of its covariance
template<typename eT>
inline
Col<eT>
gmm_full<eT>::generate() const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  // for an empty model, return an empty vector
  Col<eT> out( (N_gaus > 0) ? N_dims : uword(0), arma_nozeros_indicator() );
  Col<eT> tmp( (N_gaus > 0) ? N_dims : uword(0), fill::randn );
  
  if(N_gaus > 0)
    {
    // inverse-CDF sampling over the weights
    const double val = randu<double>();
    
    double csum = double(0);
    uword gaus_id = 0;
    
    for(uword j=0; j < N_gaus; ++j)
      {
      csum += hefts[j];
      
      if(val <= csum) { gaus_id = j; break; }
      }
    
    // NOTE(review): if round-off causes val to exceed the final cumulative sum,
    // gaus_id remains 0 (first gaussian) rather than the last — confirm intended
    
    // y = L*z + mu, where L is the lower cholesky factor of the chosen covariance
    out = chol_fcovs.slice(gaus_id) * tmp;
    out += means.col(gaus_id);
    }
  
  return out;
  }
402
403
404
//! draw N_vec random vectors from the model, one per column of the returned matrix;
//! each column independently selects a gaussian via the weights and
//! transforms a standard-normal sample by that gaussian's cholesky factor
template<typename eT>
inline
Mat<eT>
gmm_full<eT>::generate(const uword N_vec) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  // for an empty model, the result has zero rows
  Mat<eT> out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, arma_nozeros_indicator() );
  Mat<eT> tmp( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn );
  
  if(N_gaus > 0)
    {
    const eT* hefts_mem = hefts.memptr();
    
    for(uword i=0; i < N_vec; ++i)
      {
      // inverse-CDF sampling over the weights
      const double val = randu<double>();
      
      double csum = double(0);
      uword gaus_id = 0;
      
      for(uword j=0; j < N_gaus; ++j)
        {
        csum += hefts_mem[j];
        
        if(val <= csum) { gaus_id = j; break; }
        }
      
      // aliases into the i-th columns; no copying (strict auxiliary-memory views)
      Col<eT> out_vec(out.colptr(i), N_dims, false, true);
      Col<eT> tmp_vec(tmp.colptr(i), N_dims, false, true);
      
      // y = L*z + mu
      out_vec = chol_fcovs.slice(gaus_id) * tmp_vec;
      out_vec += means.col(gaus_id);
      }
    }
  
  return out;
  }
446
447
448
449 template<typename eT>
450 template<typename T1>
451 inline
452 eT
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const453 gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
454 {
455 arma_extra_debug_sigprint();
456 arma_ignore(junk1);
457 arma_ignore(junk2);
458
459 const uword N_dims = means.n_rows;
460
461 const quasi_unwrap<T1> U(expr);
462
463 arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
464
465 return internal_scalar_log_p( U.M.memptr() );
466 }
467
468
469
470 template<typename eT>
471 template<typename T1>
472 inline
473 eT
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk2) const474 gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk2) const
475 {
476 arma_extra_debug_sigprint();
477 arma_ignore(junk2);
478
479 const uword N_dims = means.n_rows;
480
481 const quasi_unwrap<T1> U(expr);
482
483 arma_debug_check( (U.M.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
484 arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
485
486 return internal_scalar_log_p( U.M.memptr(), gaus_id );
487 }
488
489
490
491 template<typename eT>
492 template<typename T1>
493 inline
494 Row<eT>
log_p(const T1 & expr,const gmm_empty_arg & junk1,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const495 gmm_full<eT>::log_p(const T1& expr, const gmm_empty_arg& junk1, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
496 {
497 arma_extra_debug_sigprint();
498 arma_ignore(junk1);
499 arma_ignore(junk2);
500
501 const quasi_unwrap<T1> tmp(expr);
502
503 const Mat<eT>& X = tmp.M;
504
505 return internal_vec_log_p(X);
506 }
507
508
509
510 template<typename eT>
511 template<typename T1>
512 inline
513 Row<eT>
log_p(const T1 & expr,const uword gaus_id,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk2) const514 gmm_full<eT>::log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk2) const
515 {
516 arma_extra_debug_sigprint();
517 arma_ignore(junk2);
518
519 const quasi_unwrap<T1> tmp(expr);
520
521 const Mat<eT>& X = tmp.M;
522
523 return internal_vec_log_p(X, gaus_id);
524 }
525
526
527
528 template<typename eT>
529 template<typename T1>
530 inline
531 eT
sum_log_p(const Base<eT,T1> & expr) const532 gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr) const
533 {
534 arma_extra_debug_sigprint();
535
536 const quasi_unwrap<T1> tmp(expr.get_ref());
537
538 const Mat<eT>& X = tmp.M;
539
540 return internal_sum_log_p(X);
541 }
542
543
544
545 template<typename eT>
546 template<typename T1>
547 inline
548 eT
sum_log_p(const Base<eT,T1> & expr,const uword gaus_id) const549 gmm_full<eT>::sum_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
550 {
551 arma_extra_debug_sigprint();
552
553 const quasi_unwrap<T1> tmp(expr.get_ref());
554
555 const Mat<eT>& X = tmp.M;
556
557 return internal_sum_log_p(X, gaus_id);
558 }
559
560
561
562 template<typename eT>
563 template<typename T1>
564 inline
565 eT
avg_log_p(const Base<eT,T1> & expr) const566 gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr) const
567 {
568 arma_extra_debug_sigprint();
569
570 const quasi_unwrap<T1> tmp(expr.get_ref());
571
572 const Mat<eT>& X = tmp.M;
573
574 return internal_avg_log_p(X);
575 }
576
577
578
579 template<typename eT>
580 template<typename T1>
581 inline
582 eT
avg_log_p(const Base<eT,T1> & expr,const uword gaus_id) const583 gmm_full<eT>::avg_log_p(const Base<eT,T1>& expr, const uword gaus_id) const
584 {
585 arma_extra_debug_sigprint();
586
587 const quasi_unwrap<T1> tmp(expr.get_ref());
588
589 const Mat<eT>& X = tmp.M;
590
591 return internal_avg_log_p(X, gaus_id);
592 }
593
594
595
596 template<typename eT>
597 template<typename T1>
598 inline
599 uword
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==true))>::result * junk) const600 gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == true))>::result* junk) const
601 {
602 arma_extra_debug_sigprint();
603 arma_ignore(junk);
604
605 const quasi_unwrap<T1> tmp(expr);
606
607 const Mat<eT>& X = tmp.M;
608
609 return internal_scalar_assign(X, dist);
610 }
611
612
613
614 template<typename eT>
615 template<typename T1>
616 inline
617 urowvec
assign(const T1 & expr,const gmm_dist_mode & dist,typename enable_if<((is_arma_type<T1>::value)&& (resolves_to_colvector<T1>::value==false))>::result * junk) const618 gmm_full<eT>::assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type<T1>::value) && (resolves_to_colvector<T1>::value == false))>::result* junk) const
619 {
620 arma_extra_debug_sigprint();
621 arma_ignore(junk);
622
623 urowvec out;
624
625 const quasi_unwrap<T1> tmp(expr);
626
627 const Mat<eT>& X = tmp.M;
628
629 internal_vec_assign(out, X, dist);
630
631 return out;
632 }
633
634
635
636 template<typename eT>
637 template<typename T1>
638 inline
639 urowvec
raw_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const640 gmm_full<eT>::raw_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
641 {
642 arma_extra_debug_sigprint();
643
644 const unwrap<T1> tmp(expr.get_ref());
645 const Mat<eT>& X = tmp.M;
646
647 arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::raw_hist(): incompatible dimensions" );
648
649 arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::raw_hist(): unsupported distance mode" );
650
651 urowvec hist;
652
653 internal_raw_hist(hist, X, dist_mode);
654
655 return hist;
656 }
657
658
659
660 template<typename eT>
661 template<typename T1>
662 inline
663 Row<eT>
norm_hist(const Base<eT,T1> & expr,const gmm_dist_mode & dist_mode) const664 gmm_full<eT>::norm_hist(const Base<eT,T1>& expr, const gmm_dist_mode& dist_mode) const
665 {
666 arma_extra_debug_sigprint();
667
668 const unwrap<T1> tmp(expr.get_ref());
669 const Mat<eT>& X = tmp.M;
670
671 arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::norm_hist(): incompatible dimensions" );
672
673 arma_debug_check( ((dist_mode != eucl_dist) && (dist_mode != prob_dist)), "gmm_full::norm_hist(): unsupported distance mode" );
674
675 urowvec hist;
676
677 internal_raw_hist(hist, X, dist_mode);
678
679 const uword hist_n_elem = hist.n_elem;
680 const uword* hist_mem = hist.memptr();
681
682 eT acc = eT(0);
683 for(uword i=0; i<hist_n_elem; ++i) { acc += eT(hist_mem[i]); }
684
685 if(acc == eT(0)) { acc = eT(1); }
686
687 Row<eT> out(hist_n_elem, arma_nozeros_indicator());
688
689 eT* out_mem = out.memptr();
690
691 for(uword i=0; i<hist_n_elem; ++i) { out_mem[i] = eT(hist_mem[i]) / acc; }
692
693 return out;
694 }
695
696
697
//! train the model on the given data (one sample per column);
//! pipeline: validate arguments -> optional initial means -> optional k-means -> initial covariances -> optional EM;
//! on failure the previous model is restored and false is returned
template<typename eT>
template<typename T1>
inline
bool
gmm_full<eT>::learn
  (
  const Base<eT,T1>& data,
  const uword N_gaus,
  const gmm_dist_mode& dist_mode,
  const gmm_seed_mode& seed_mode,
  const uword km_iter,
  const uword em_iter,
  const eT var_floor,
  const bool print_mode
  )
  {
  arma_extra_debug_sigprint();
  
  const bool dist_mode_ok = (dist_mode == eucl_dist) || (dist_mode == maha_dist);
  
  const bool seed_mode_ok = \
       (seed_mode == keep_existing)
    || (seed_mode == static_subset)
    || (seed_mode == static_spread)
    || (seed_mode == random_subset)
    || (seed_mode == random_spread);
  
  arma_debug_check( (dist_mode_ok == false), "gmm_full::learn(): dist_mode must be eucl_dist or maha_dist" );
  arma_debug_check( (seed_mode_ok == false), "gmm_full::learn(): unknown seed_mode" );
  arma_debug_check( (var_floor < eT(0) ), "gmm_full::learn(): variance floor is negative" );
  
  const unwrap<T1> tmp_X(data.get_ref());
  const Mat<eT>& X = tmp_X.M;
  
  if(X.is_empty() ) { arma_debug_warn_level(3, "gmm_full::learn(): given matrix is empty" ); return false; }
  if(X.is_finite() == false) { arma_debug_warn_level(3, "gmm_full::learn(): given matrix has non-finite values"); return false; }
  
  // zero gaussians requested: succeed with an empty model
  if(N_gaus == 0) { reset(); return true; }
  
  if(dist_mode == maha_dist)
    {
    // per-dimension inverse variances used for Mahalanobis-like scaling;
    // zero or non-finite variances fall back to a scale of 1
    mah_aux = var(X,1,1);
    
    const uword mah_aux_n_elem = mah_aux.n_elem;
    eT* mah_aux_mem = mah_aux.memptr();
    
    for(uword i=0; i < mah_aux_n_elem; ++i)
      {
      const eT val = mah_aux_mem[i];
      
      mah_aux_mem[i] = ((val != eT(0)) && arma_isfinite(val)) ? eT(1) / val : eT(1);
      }
    }
  
  
  // copy current model, in case of failure by k-means and/or EM
  
  const gmm_full<eT> orig = (*this);
  
  
  // initial means
  
  if(seed_mode == keep_existing)
    {
    if(means.is_empty() ) { arma_debug_warn_level(3, "gmm_full::learn(): no existing means" ); return false; }
    if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "gmm_full::learn(): dimensionality mismatch"); return false; }
    
    // TODO: also check for number of vectors?
    }
  else
    {
    if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "gmm_full::learn(): number of vectors is less than number of gaussians"); return false; }
    
    reset(X.n_rows, N_gaus);
    
    if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial means\n"; get_cout_stream().flush(); }
    
    // template parameter selects euclidean (1) or mahalanobis (2) distance
         if(dist_mode == eucl_dist) { generate_initial_means<1>(X, seed_mode); }
    else if(dist_mode == maha_dist) { generate_initial_means<2>(X, seed_mode); }
    }
  
  
  // k-means
  
  if(km_iter > 0)
    {
    // preserve and restore stream formatting flags around progress printing
    const arma_ostream_state stream_state(get_cout_stream());
    
    bool status = false;
    
         if(dist_mode == eucl_dist) { status = km_iterate<1>(X, km_iter, print_mode); }
    else if(dist_mode == maha_dist) { status = km_iterate<2>(X, km_iter, print_mode); }
    
    stream_state.restore(get_cout_stream());
    
    if(status == false) { arma_debug_warn_level(3, "gmm_full::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; }
    }
  
  
  // initial fcovs
  
  // a zero user-supplied floor is replaced by the smallest positive normalised value
  const eT var_floor_actual = (eT(var_floor) > eT(0)) ? eT(var_floor) : std::numeric_limits<eT>::min();
  
  if(seed_mode != keep_existing)
    {
    if(print_mode) { get_cout_stream() << "gmm_full::learn(): generating initial covariances\n"; get_cout_stream().flush(); }
    
         if(dist_mode == eucl_dist) { generate_initial_params<1>(X, var_floor_actual); }
    else if(dist_mode == maha_dist) { generate_initial_params<2>(X, var_floor_actual); }
    }
  
  
  // EM algorithm
  
  if(em_iter > 0)
    {
    const arma_ostream_state stream_state(get_cout_stream());
    
    const bool status = em_iterate(X, em_iter, var_floor_actual, print_mode);
    
    stream_state.restore(get_cout_stream());
    
    if(status == false) { arma_debug_warn_level(3, "gmm_full::learn(): EM algorithm failed"); init(orig); return false; }
    }
  
  // scaling vector is only needed during training
  mah_aux.reset();
  
  init_constants();
  
  return true;
  }
829
830
831
832 //
833 //
834 //
835
836
837
838 template<typename eT>
839 inline
840 void
init(const gmm_full<eT> & x)841 gmm_full<eT>::init(const gmm_full<eT>& x)
842 {
843 arma_extra_debug_sigprint();
844
845 gmm_full<eT>& t = *this;
846
847 if(&t != &x)
848 {
849 access::rw(t.means) = x.means;
850 access::rw(t.fcovs) = x.fcovs;
851 access::rw(t.hefts) = x.hefts;
852
853 init_constants();
854 }
855 }
856
857
858
859 template<typename eT>
860 inline
861 void
init(const gmm_diag<eT> & x)862 gmm_full<eT>::init(const gmm_diag<eT>& x)
863 {
864 arma_extra_debug_sigprint();
865
866 access::rw(hefts) = x.hefts;
867 access::rw(means) = x.means;
868
869 const uword N_dims = x.means.n_rows;
870 const uword N_gaus = x.means.n_cols;
871
872 access::rw(fcovs).zeros(N_dims,N_dims,N_gaus);
873
874 for(uword g=0; g < N_gaus; ++g)
875 {
876 Mat<eT>& fcov = access::rw(fcovs).slice(g);
877
878 const eT* dcov_mem = x.dcovs.colptr(g);
879
880 for(uword d=0; d < N_dims; ++d)
881 {
882 fcov.at(d,d) = dcov_mem[d];
883 }
884 }
885
886 init_constants();
887 }
888
889
890
891 template<typename eT>
892 inline
893 void
init(const uword in_n_dims,const uword in_n_gaus)894 gmm_full<eT>::init(const uword in_n_dims, const uword in_n_gaus)
895 {
896 arma_extra_debug_sigprint();
897
898 access::rw(means).zeros(in_n_dims, in_n_gaus);
899
900 access::rw(fcovs).zeros(in_n_dims, in_n_dims, in_n_gaus);
901
902 for(uword g=0; g < in_n_gaus; ++g)
903 {
904 access::rw(fcovs).slice(g).diag().ones();
905 }
906
907 access::rw(hefts).set_size(in_n_gaus);
908 access::rw(hefts).fill(eT(1) / eT(in_n_gaus));
909
910 init_constants();
911 }
912
913
914
//! recompute cached quantities derived from the model parameters:
//! covariance inverses, per-gaussian log-density constants, clamped log weights,
//! and (optionally) cholesky factors used by generate()
template<typename eT>
inline
void
gmm_full<eT>::init_constants(const bool calc_chol)
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  // normalisation term of the gaussian log-density: (D/2) * log(2*pi)
  const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum<eT>::pi);
  
  // inverses and log-determinant constants
  
  inv_fcovs.copy_size(fcovs);
  log_det_etc.set_size(N_gaus);
  
  Mat<eT> tmp_inv;
  
  for(uword g=0; g < N_gaus; ++g)
    {
    const Mat<eT>& fcov = fcovs.slice(g);
    Mat<eT>& inv_fcov = inv_fcovs.slice(g);
    
    // inversion specialised for symmetric positive-definite matrices
    //const bool inv_ok = auxlib::inv(tmp_inv, fcov);
    const bool inv_ok = auxlib::inv_sympd(tmp_inv, fcov);
    
    eT log_det_val = eT(0);
    eT log_det_sign = eT(0);
    
    const bool log_det_status = log_det(log_det_val, log_det_sign, fcov);
    
    // the determinant must be finite and positive for a valid covariance
    const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
    
    if(inv_ok && log_det_ok)
      {
      inv_fcov = tmp_inv;
      }
    else
      {
      // last resort: treat the covariance matrix as diagonal
      
      inv_fcov.zeros();
      
      log_det_val = eT(0);
      
      for(uword d=0; d < N_dims; ++d)
        {
        // clamp the diagonal entry away from zero to keep 1/x and log(x) finite
        const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
        
        inv_fcov.at(d,d) = eT(1) / sanitised_val;
        
        log_det_val += std::log(sanitised_val);
        }
      }
    
    // constant part of the log-density: -( (D/2)*log(2*pi) + 0.5*log|Sigma| )
    log_det_etc[g] = eT(-1) * ( tmp + eT(0.5) * log_det_val );
    }
  
  // clamp weights away from zero so log(hefts) stays finite
  
  eT* hefts_mem = access::rw(hefts).memptr();
  
  for(uword g=0; g < N_gaus; ++g)
    {
    hefts_mem[g] = (std::max)( hefts_mem[g], std::numeric_limits<eT>::min() );
    }
  
  log_hefts = log(hefts);
  
  
  if(calc_chol)
    {
    // cholesky factors are used when sampling via generate()
    chol_fcovs.copy_size(fcovs);
    
    Mat<eT> tmp_chol;
    
    for(uword g=0; g < N_gaus; ++g)
      {
      const Mat<eT>& fcov = fcovs.slice(g);
      Mat<eT>& chol_fcov = chol_fcovs.slice(g);
      
      const uword chol_layout = 1; // indicates "lower"
      
      const bool chol_ok = op_chol::apply_direct(tmp_chol, fcov, chol_layout);
      
      if(chol_ok)
        {
        chol_fcov = tmp_chol;
        }
      else
        {
        // last resort: treat the covariance matrix as diagonal
        
        chol_fcov.zeros();
        
        for(uword d=0; d < N_dims; ++d)
          {
          const eT sanitised_val = (std::max)( eT(fcov.at(d,d)), eT(std::numeric_limits<eT>::min()) );
          
          chol_fcov.at(d,d) = std::sqrt(sanitised_val);
          }
        }
      }
    }
  }
1021
1022
1023
//! split N work items into per-thread [start,end] index ranges (inclusive);
//! returns a 2 x n_threads matrix: row 0 = start indices, row 1 = end indices;
//! with OpenMP disabled, a single range covering all items is returned
template<typename eT>
inline
umat
gmm_full<eT>::internal_gen_boundaries(const uword N) const
  {
  arma_extra_debug_sigprint();
  
  #if defined(ARMA_USE_OPENMP)
    // fall back to one thread when there are fewer items than threads
    const uword n_threads_avail = uword(omp_get_max_threads());
    const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1;
  #else
    static constexpr uword n_threads = 1;
  #endif
  
  // get_cout_stream() << "gmm_full::internal_gen_boundaries(): n_threads: " << n_threads << '\n';
  
  umat boundaries(2, n_threads, arma_nozeros_indicator());
  
  if(N > 0)
    {
    const uword chunk_size = N / n_threads;
    
    uword count = 0;
    
    for(uword t=0; t<n_threads; t++)
      {
      boundaries.at(0,t) = count;
      
      count += chunk_size;
      
      boundaries.at(1,t) = count-1;
      }
    
    // the last thread also absorbs the remainder of N / n_threads
    boundaries.at(1,n_threads-1) = N - 1;
    }
  else
    {
    boundaries.zeros();
    }
  
  // get_cout_stream() << "gmm_full::internal_gen_boundaries(): boundaries: " << '\n' << boundaries << '\n';
  
  return boundaries;
  }
1068
1069
1070
//! log-likelihood of one vector under the whole mixture:
//! combines per-gaussian log-densities (weighted by log hefts) via numerically-stable log-sum-exp
template<typename eT>
inline
eT
gmm_full<eT>::internal_scalar_log_p(const eT* x) const
  {
  arma_extra_debug_sigprint();
  
  const eT* log_hefts_mem = log_hefts.mem;
  
  const uword N_gaus = means.n_cols;
  
  if(N_gaus > 0)
    {
    // seed the accumulator with the first gaussian, then fold in the rest
    eT log_sum = internal_scalar_log_p(x, 0) + log_hefts_mem[0];
    
    for(uword g=1; g < N_gaus; ++g)
      {
      const eT log_val = internal_scalar_log_p(x, g) + log_hefts_mem[g];
      
      log_sum = log_add_exp(log_sum, log_val);
      }
    
    return log_sum;
    }
  else
    {
    // empty model: zero probability
    return -Datum<eT>::inf;
    }
  }
1100
1101
1102
//! log-density of one vector under gaussian g:
//! computes the quadratic form (x-mu)' * inv(Sigma) * (x-mu) by walking the
//! inverse-covariance matrix column by column, then adds the cached constant term
template<typename eT>
inline
eT
gmm_full<eT>::internal_scalar_log_p(const eT* x, const uword g) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const eT* mean_mem = means.colptr(g);
  
  eT outer_acc = eT(0);
  
  const eT* inv_fcov_coldata = inv_fcovs.slice(g).memptr();
  
  for(uword i=0; i < N_dims; ++i)
    {
    // inner_acc = dot( (x - mu), column i of inv(Sigma) )
    eT inner_acc = eT(0);
    
    for(uword j=0; j < N_dims; ++j)
      {
      inner_acc += (x[j] - mean_mem[j]) * inv_fcov_coldata[j];
      }
    
    // advance to the next column (column-major storage)
    inv_fcov_coldata += N_dims;
    
    outer_acc += inner_acc * (x[i] - mean_mem[i]);
    }
  
  // log_det_etc[g] holds -( (D/2)*log(2*pi) + 0.5*log|Sigma| )
  return eT(-0.5)*outer_acc + log_det_etc.mem[g];
  }
1133
1134
1135
//! log-likelihood of each column of X under the whole mixture;
//! parallelised over column ranges when OpenMP is available
template<typename eT>
inline
Row<eT>
gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_samples = X.n_cols;
  
  arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
  
  Row<eT> out(N_samples, arma_nozeros_indicator());
  
  if(N_samples > 0)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      // each thread handles a disjoint inclusive range of columns
      const umat boundaries = internal_gen_boundaries(N_samples);
      
      const uword n_threads = boundaries.n_cols;
      
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        const uword start_index = boundaries.at(0,t);
        const uword end_index = boundaries.at(1,t);
        
        eT* out_mem = out.memptr();
        
        for(uword i=start_index; i <= end_index; ++i)
          {
          out_mem[i] = internal_scalar_log_p( X.colptr(i) );
          }
        }
      }
    #else
      {
      eT* out_mem = out.memptr();
      
      for(uword i=0; i < N_samples; ++i)
        {
        out_mem[i] = internal_scalar_log_p( X.colptr(i) );
        }
      }
    #endif
    }
  
  return out;
  }
1186
1187
1188
//! log-density of each column of X under the single gaussian identified by gaus_id;
//! parallelised over column ranges when OpenMP is available
template<typename eT>
inline
Row<eT>
gmm_full<eT>::internal_vec_log_p(const Mat<eT>& X, const uword gaus_id) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_samples = X.n_cols;
  
  arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" );
  arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" );
  
  Row<eT> out(N_samples, arma_nozeros_indicator());
  
  if(N_samples > 0)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      // each thread handles a disjoint inclusive range of columns
      const umat boundaries = internal_gen_boundaries(N_samples);
      
      const uword n_threads = boundaries.n_cols;
      
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        const uword start_index = boundaries.at(0,t);
        const uword end_index = boundaries.at(1,t);
        
        eT* out_mem = out.memptr();
        
        for(uword i=start_index; i <= end_index; ++i)
          {
          out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
          }
        }
      }
    #else
      {
      eT* out_mem = out.memptr();
      
      for(uword i=0; i < N_samples; ++i)
        {
        out_mem[i] = internal_scalar_log_p( X.colptr(i), gaus_id );
        }
      }
    #endif
    }
  
  return out;
  }
1240
1241
1242
1243 template<typename eT>
1244 inline
1245 eT
internal_sum_log_p(const Mat<eT> & X) const1246 gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X) const
1247 {
1248 arma_extra_debug_sigprint();
1249
1250 arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" );
1251
1252 const uword N = X.n_cols;
1253
1254 if(N == 0) { return (-Datum<eT>::inf); }
1255
1256
1257 #if defined(ARMA_USE_OPENMP)
1258 {
1259 const umat boundaries = internal_gen_boundaries(N);
1260
1261 const uword n_threads = boundaries.n_cols;
1262
1263 Col<eT> t_accs(n_threads, arma_zeros_indicator());
1264
1265 #pragma omp parallel for schedule(static)
1266 for(uword t=0; t < n_threads; ++t)
1267 {
1268 const uword start_index = boundaries.at(0,t);
1269 const uword end_index = boundaries.at(1,t);
1270
1271 eT t_acc = eT(0);
1272
1273 for(uword i=start_index; i <= end_index; ++i)
1274 {
1275 t_acc += internal_scalar_log_p( X.colptr(i) );
1276 }
1277
1278 t_accs[t] = t_acc;
1279 }
1280
1281 return eT(accu(t_accs));
1282 }
1283 #else
1284 {
1285 eT acc = eT(0);
1286
1287 for(uword i=0; i<N; ++i)
1288 {
1289 acc += internal_scalar_log_p( X.colptr(i) );
1290 }
1291
1292 return acc;
1293 }
1294 #endif
1295 }
1296
1297
1298
1299 template<typename eT>
1300 inline
1301 eT
internal_sum_log_p(const Mat<eT> & X,const uword gaus_id) const1302 gmm_full<eT>::internal_sum_log_p(const Mat<eT>& X, const uword gaus_id) const
1303 {
1304 arma_extra_debug_sigprint();
1305
1306 arma_debug_check( (X.n_rows != means.n_rows), "gmm_full::sum_log_p(): incompatible dimensions" );
1307 arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::sum_log_p(): specified gaussian is out of range" );
1308
1309 const uword N = X.n_cols;
1310
1311 if(N == 0) { return (-Datum<eT>::inf); }
1312
1313
1314 #if defined(ARMA_USE_OPENMP)
1315 {
1316 const umat boundaries = internal_gen_boundaries(N);
1317
1318 const uword n_threads = boundaries.n_cols;
1319
1320 Col<eT> t_accs(n_threads, arma_zeros_indicator());
1321
1322 #pragma omp parallel for schedule(static)
1323 for(uword t=0; t < n_threads; ++t)
1324 {
1325 const uword start_index = boundaries.at(0,t);
1326 const uword end_index = boundaries.at(1,t);
1327
1328 eT t_acc = eT(0);
1329
1330 for(uword i=start_index; i <= end_index; ++i)
1331 {
1332 t_acc += internal_scalar_log_p( X.colptr(i), gaus_id );
1333 }
1334
1335 t_accs[t] = t_acc;
1336 }
1337
1338 return eT(accu(t_accs));
1339 }
1340 #else
1341 {
1342 eT acc = eT(0);
1343
1344 for(uword i=0; i<N; ++i)
1345 {
1346 acc += internal_scalar_log_p( X.colptr(i), gaus_id );
1347 }
1348
1349 return acc;
1350 }
1351 #endif
1352 }
1353
1354
1355
//! compute the average log-likelihood of all samples (columns) in X;
//! returns -inf if X is empty
template<typename eT>
inline
eT
gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims    = means.n_rows;
  const uword N_samples = X.n_cols;
  
  arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" );
  
  if(N_samples == 0)  { return (-Datum<eT>::inf); }
  
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(N_samples);
    
    const uword n_threads = boundaries.n_cols;
    
    // one running mean per thread; combined after the parallel region
    field< running_mean_scalar<eT> > t_running_means(n_threads);
    
    
    #pragma omp parallel for schedule(static)
    for(uword t=0; t < n_threads; ++t)
      {
      const uword start_index = boundaries.at(0,t);
      const uword end_index   = boundaries.at(1,t);  // inclusive
      
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      for(uword i=start_index; i <= end_index; ++i)
        {
        current_running_mean( internal_scalar_log_p( X.colptr(i) ) );
        }
      }
    
    
    // combine per-thread means into the overall mean;
    // each thread's mean is weighted by its share of the samples,
    // so threads with uneven chunk sizes are handled correctly
    eT avg = eT(0);
    
    for(uword t=0; t < n_threads; ++t)
      {
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      const eT w = eT(current_running_mean.count()) / eT(N_samples);
      
      avg += w * current_running_mean.mean();
      }
    
    return avg;
    }
  #else
    {
    running_mean_scalar<eT> running_mean;
    
    for(uword i=0; i < N_samples; ++i)
      {
      running_mean( internal_scalar_log_p( X.colptr(i) ) );
      }
    
    return running_mean.mean();
    }
  #endif
  }
1421
1422
1423
//! compute the average log-likelihood of all samples (columns) in X
//! under the single gaussian identified by gaus_id;
//! returns -inf if X is empty
template<typename eT>
inline
eT
gmm_full<eT>::internal_avg_log_p(const Mat<eT>& X, const uword gaus_id) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims    = means.n_rows;
  const uword N_samples = X.n_cols;
  
  arma_debug_check( (X.n_rows != N_dims), "gmm_full::avg_log_p(): incompatible dimensions" );
  arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::avg_log_p(): specified gaussian is out of range" );
  
  if(N_samples == 0)  { return (-Datum<eT>::inf); }
  
  
  #if defined(ARMA_USE_OPENMP)
    {
    const umat boundaries = internal_gen_boundaries(N_samples);
    
    const uword n_threads = boundaries.n_cols;
    
    // one running mean per thread; combined after the parallel region
    field< running_mean_scalar<eT> > t_running_means(n_threads);
    
    
    #pragma omp parallel for schedule(static)
    for(uword t=0; t < n_threads; ++t)
      {
      const uword start_index = boundaries.at(0,t);
      const uword end_index   = boundaries.at(1,t);  // inclusive
      
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      for(uword i=start_index; i <= end_index; ++i)
        {
        current_running_mean( internal_scalar_log_p( X.colptr(i), gaus_id) );
        }
      }
    
    
    // combine per-thread means into the overall mean;
    // each thread's mean is weighted by its share of the samples
    eT avg = eT(0);
    
    for(uword t=0; t < n_threads; ++t)
      {
      running_mean_scalar<eT>& current_running_mean = t_running_means[t];
      
      const eT w = eT(current_running_mean.count()) / eT(N_samples);
      
      avg += w * current_running_mean.mean();
      }
    
    return avg;
    }
  #else
    {
    running_mean_scalar<eT> running_mean;
    
    for(uword i=0; i<N_samples; ++i)
      {
      running_mean( internal_scalar_log_p( X.colptr(i), gaus_id ) );
      }
    
    return running_mean.mean();
    }
  #endif
  }
1490
1491
1492
1493 template<typename eT>
1494 inline
1495 uword
internal_scalar_assign(const Mat<eT> & X,const gmm_dist_mode & dist_mode) const1496 gmm_full<eT>::internal_scalar_assign(const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
1497 {
1498 arma_extra_debug_sigprint();
1499
1500 const uword N_dims = means.n_rows;
1501 const uword N_gaus = means.n_cols;
1502
1503 arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
1504 arma_debug_check( (N_gaus == 0), "gmm_full::assign(): model has no means" );
1505
1506 const eT* X_mem = X.colptr(0);
1507
1508 if(dist_mode == eucl_dist)
1509 {
1510 eT best_dist = Datum<eT>::inf;
1511 uword best_g = 0;
1512
1513 for(uword g=0; g < N_gaus; ++g)
1514 {
1515 const eT tmp_dist = distance<eT,1>::eval(N_dims, X_mem, means.colptr(g), X_mem);
1516
1517 if(tmp_dist <= best_dist)
1518 {
1519 best_dist = tmp_dist;
1520 best_g = g;
1521 }
1522 }
1523
1524 return best_g;
1525 }
1526 else
1527 if(dist_mode == prob_dist)
1528 {
1529 const eT* log_hefts_mem = log_hefts.memptr();
1530
1531 eT best_p = -Datum<eT>::inf;
1532 uword best_g = 0;
1533
1534 for(uword g=0; g < N_gaus; ++g)
1535 {
1536 const eT tmp_p = internal_scalar_log_p(X_mem, g) + log_hefts_mem[g];
1537
1538 if(tmp_p >= best_p)
1539 {
1540 best_p = tmp_p;
1541 best_g = g;
1542 }
1543 }
1544
1545 return best_g;
1546 }
1547 else
1548 {
1549 arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
1550 }
1551
1552 return uword(0);
1553 }
1554
1555
1556
//! assign each sample (column) in X to its most suitable gaussian,
//! writing the winning gaussian index for each sample into 'out';
//! dist_mode selects euclidean distance or heft-weighted log-likelihood
template<typename eT>
inline
void
gmm_full<eT>::internal_vec_assign(urowvec& out, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  arma_debug_check( (X.n_rows != N_dims), "gmm_full::assign(): incompatible dimensions" );
  
  // with no gaussians there is nothing to assign to, so produce an empty output
  const uword X_n_cols = (N_gaus > 0) ? X.n_cols : 0;
  
  out.set_size(1,X_n_cols);
  
  uword* out_mem = out.memptr();
  
  if(dist_mode == eucl_dist)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      #pragma omp parallel for schedule(static)
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT    best_dist = Datum<eT>::inf;
        uword best_g    = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          // NOTE(review): the trailing argument is presumably ignored for
          // euclidean distance (dist_id 1); mirrors the other call sites
          const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
          
          // <= comparison: ties resolve to the highest gaussian index
          if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
          }
        
        out_mem[i] = best_g;
        }
      }
    #else
      {
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT    best_dist = Datum<eT>::inf;
        uword best_g    = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
          
          if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
          }
        
        out_mem[i] = best_g;
        }
      }
    #endif
    }
  else
  if(dist_mode == prob_dist)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      // partition the samples into contiguous chunks, one chunk per thread
      const umat boundaries = internal_gen_boundaries(X_n_cols);
      
      const uword n_threads = boundaries.n_cols;
      
      const eT* log_hefts_mem = log_hefts.memptr();
      
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        const uword start_index = boundaries.at(0,t);
        const uword end_index   = boundaries.at(1,t);  // inclusive
        
        for(uword i=start_index; i <= end_index; ++i)
          {
          const eT* X_colptr = X.colptr(i);
          
          eT    best_p = -Datum<eT>::inf;
          uword best_g = 0;
          
          for(uword g=0; g<N_gaus; ++g)
            {
            // weighted log-likelihood of the sample under gaussian g
            const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
            
            // >= comparison: ties resolve to the highest gaussian index
            if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
            }
          
          out_mem[i] = best_g;
          }
        }
      }
    #else
      {
      const eT* log_hefts_mem = log_hefts.memptr();
      
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT    best_p = -Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
          
          if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
          }
        
        out_mem[i] = best_g;
        }
      }
    #endif
    }
  else
    {
    arma_debug_check(true, "gmm_full::assign(): unsupported distance mode");
    }
  }
1681
1682
1683
1684
//! build a histogram (raw counts) of how many samples in X are assigned
//! to each gaussian, using the given distance mode; the histogram has
//! one bin per gaussian and is written into 'hist'
template<typename eT>
inline
void
gmm_full<eT>::internal_raw_hist(urowvec& hist, const Mat<eT>& X, const gmm_dist_mode& dist_mode) const
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  const uword X_n_cols = X.n_cols;
  
  hist.zeros(N_gaus);
  
  if(N_gaus == 0)  { return; }
  
  #if defined(ARMA_USE_OPENMP)
    {
    // partition the samples into contiguous chunks, one chunk per thread
    const umat boundaries = internal_gen_boundaries(X_n_cols);
    
    const uword n_threads = boundaries.n_cols;
    
    // each thread accumulates into its own histogram; reduced at the end
    field<urowvec> thread_hist(n_threads);
    
    for(uword t=0; t < n_threads; ++t)  { thread_hist(t).zeros(N_gaus); }
    
    
    if(dist_mode == eucl_dist)
      {
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        uword* thread_hist_mem = thread_hist(t).memptr();
        
        const uword start_index = boundaries.at(0,t);
        const uword end_index   = boundaries.at(1,t);  // inclusive
        
        for(uword i=start_index; i <= end_index; ++i)
          {
          const eT* X_colptr = X.colptr(i);
          
          eT    best_dist = Datum<eT>::inf;
          uword best_g    = 0;
          
          for(uword g=0; g < N_gaus; ++g)
            {
            // NOTE(review): trailing argument presumably ignored for
            // euclidean distance; mirrors the other call sites
            const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
            
            if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
            }
          
          thread_hist_mem[best_g]++;
          }
        }
      }
    else
    if(dist_mode == prob_dist)
      {
      const eT* log_hefts_mem = log_hefts.memptr();
      
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        uword* thread_hist_mem = thread_hist(t).memptr();
        
        const uword start_index = boundaries.at(0,t);
        const uword end_index   = boundaries.at(1,t);  // inclusive
        
        for(uword i=start_index; i <= end_index; ++i)
          {
          const eT* X_colptr = X.colptr(i);
          
          eT    best_p = -Datum<eT>::inf;
          uword best_g = 0;
          
          for(uword g=0; g < N_gaus; ++g)
            {
            // weighted log-likelihood of the sample under gaussian g
            const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
            
            if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
            }
          
          thread_hist_mem[best_g]++;
          }
        }
      }
    
    // reduction
    for(uword t=0; t < n_threads; ++t)
      {
      hist += thread_hist(t);
      }
    }
  #else
    {
    uword* hist_mem = hist.memptr();
    
    if(dist_mode == eucl_dist)
      {
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT    best_dist = Datum<eT>::inf;
        uword best_g    = 0;
        
        for(uword g=0; g < N_gaus; ++g)
          {
          const eT tmp_dist = distance<eT,1>::eval(N_dims, X_colptr, means.colptr(g), X_colptr);
          
          if(tmp_dist <= best_dist)  { best_dist = tmp_dist; best_g = g; }
          }
        
        hist_mem[best_g]++;
        }
      }
    else
    if(dist_mode == prob_dist)
      {
      const eT* log_hefts_mem = log_hefts.memptr();
      
      for(uword i=0; i<X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT    best_p = -Datum<eT>::inf;
        uword best_g = 0;
        
        for(uword g=0; g < N_gaus; ++g)
          {
          const eT tmp_p = internal_scalar_log_p(X_colptr, g) + log_hefts_mem[g];
          
          if(tmp_p >= best_p)  { best_p = tmp_p; best_g = g; }
          }
        
        hist_mem[best_g]++;
        }
      }
    }
  #endif
  }
1826
1827
1828
//! seed the gaussian means from the data X, according to seed_mode:
//! static_subset / random_subset pick N_gaus sample columns directly;
//! static_spread / random_spread greedily pick samples that are far apart
//! (distance measured via the dist_id template parameter)
template<typename eT>
template<uword dist_id>
inline
void
gmm_full<eT>::generate_initial_means(const Mat<eT>& X, const gmm_seed_mode& seed_mode)
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  if( (seed_mode == static_subset) || (seed_mode == random_subset) )
    {
    uvec initial_indices;
    
    if(seed_mode == static_subset)       { initial_indices = linspace<uvec>(0, X.n_cols-1, N_gaus); }
    else if(seed_mode == random_subset)  { initial_indices = randperm<uvec>(X.n_cols, N_gaus);      }
    
    // initial_indices.print("initial_indices:");
    
    access::rw(means) = X.cols(initial_indices);
    }
  else
  if( (seed_mode == static_spread) || (seed_mode == random_spread) )
    {
    // going through all of the samples can be extremely time consuming;
    // instead, if there are enough samples, examine roughly every 10th sample
    // (ie. decimate the data, with a varying starting offset)
    
    const bool  use_sampling = ((X.n_cols/uword(100)) > N_gaus);
    const uword step         = (use_sampling) ? uword(10) : uword(1);
    
    uword start_index = 0;
    
    if(seed_mode == static_spread)       { start_index = X.n_cols / 2;                                         }
    else if(seed_mode == random_spread)  { start_index = as_scalar(randi<uvec>(1, distr_param(0,X.n_cols-1))); }
    
    // the first mean is simply a chosen sample
    access::rw(means).col(0) = X.unsafe_col(start_index);
    
    const eT* mah_aux_mem = mah_aux.memptr();
    
    running_stat<double> rs;
    
    // each subsequent mean = the (sampled) column furthest, on average,
    // from the means chosen so far
    for(uword g=1; g < N_gaus; ++g)
      {
      eT    max_dist = eT(0);
      uword best_i   = uword(0);
      uword start_i  = uword(0);
      
      if(use_sampling)
        {
        // vary the decimation offset per mean, so different samples are examined
        uword start_i_proposed = uword(0);
        
        if(seed_mode == static_spread)  { start_i_proposed = g % uword(10);                               }
        if(seed_mode == random_spread)  { start_i_proposed = as_scalar(randi<uvec>(1, distr_param(0,9))); }
        
        if(start_i_proposed < X.n_cols)  { start_i = start_i_proposed; }
        }
      
      
      for(uword i=start_i; i < X.n_cols; i += step)
        {
        rs.reset();
        
        const eT* X_colptr = X.colptr(i);
        
        bool ignore_i = false;
        
        // find the average distance between sample i and the means so far
        for(uword h = 0; h < g; ++h)
          {
          const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(h), mah_aux_mem);
          
          // ignore sample already selected as a mean
          if(dist == eT(0))  { ignore_i = true; break; }
          else               { rs(dist);                }
          }
        
        if( (rs.mean() >= max_dist) && (ignore_i == false))
          {
          max_dist = eT(rs.mean()); best_i = i;
          }
        }
      
      // set the mean to the sample that is the furthest away from the means so far
      access::rw(means).col(g) = X.unsafe_col(best_i);
      }
    }
  
  // get_cout_stream() << "generate_initial_means():" << '\n';
  // means.print();
  }
1920
1921
1922
1923 template<typename eT>
1924 template<uword dist_id>
1925 inline
1926 void
generate_initial_params(const Mat<eT> & X,const eT var_floor)1927 gmm_full<eT>::generate_initial_params(const Mat<eT>& X, const eT var_floor)
1928 {
1929 arma_extra_debug_sigprint();
1930
1931 const uword N_dims = means.n_rows;
1932 const uword N_gaus = means.n_cols;
1933
1934 const eT* mah_aux_mem = mah_aux.memptr();
1935
1936 const uword X_n_cols = X.n_cols;
1937
1938 if(X_n_cols == 0) { return; }
1939
1940 // as the covariances are calculated via accumulators,
1941 // the means also need to be calculated via accumulators to ensure numerical consistency
1942
1943 Mat<eT> acc_means(N_dims, N_gaus);
1944 Mat<eT> acc_dcovs(N_dims, N_gaus);
1945
1946 Row<uword> acc_hefts(N_gaus, arma_zeros_indicator());
1947
1948 uword* acc_hefts_mem = acc_hefts.memptr();
1949
1950 #if defined(ARMA_USE_OPENMP)
1951 {
1952 const umat boundaries = internal_gen_boundaries(X_n_cols);
1953
1954 const uword n_threads = boundaries.n_cols;
1955
1956 field< Mat<eT> > t_acc_means(n_threads);
1957 field< Mat<eT> > t_acc_dcovs(n_threads);
1958 field< Row<uword> > t_acc_hefts(n_threads);
1959
1960 for(uword t=0; t < n_threads; ++t)
1961 {
1962 t_acc_means(t).zeros(N_dims, N_gaus);
1963 t_acc_dcovs(t).zeros(N_dims, N_gaus);
1964 t_acc_hefts(t).zeros(N_gaus);
1965 }
1966
1967 #pragma omp parallel for schedule(static)
1968 for(uword t=0; t < n_threads; ++t)
1969 {
1970 uword* t_acc_hefts_mem = t_acc_hefts(t).memptr();
1971
1972 const uword start_index = boundaries.at(0,t);
1973 const uword end_index = boundaries.at(1,t);
1974
1975 for(uword i=start_index; i <= end_index; ++i)
1976 {
1977 const eT* X_colptr = X.colptr(i);
1978
1979 eT min_dist = Datum<eT>::inf;
1980 uword best_g = 0;
1981
1982 for(uword g=0; g<N_gaus; ++g)
1983 {
1984 const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
1985
1986 if(dist < min_dist) { min_dist = dist; best_g = g; }
1987 }
1988
1989 eT* t_acc_mean = t_acc_means(t).colptr(best_g);
1990 eT* t_acc_dcov = t_acc_dcovs(t).colptr(best_g);
1991
1992 for(uword d=0; d<N_dims; ++d)
1993 {
1994 const eT x_d = X_colptr[d];
1995
1996 t_acc_mean[d] += x_d;
1997 t_acc_dcov[d] += x_d*x_d;
1998 }
1999
2000 t_acc_hefts_mem[best_g]++;
2001 }
2002 }
2003
2004 // reduction
2005 acc_means = t_acc_means(0);
2006 acc_dcovs = t_acc_dcovs(0);
2007 acc_hefts = t_acc_hefts(0);
2008
2009 for(uword t=1; t < n_threads; ++t)
2010 {
2011 acc_means += t_acc_means(t);
2012 acc_dcovs += t_acc_dcovs(t);
2013 acc_hefts += t_acc_hefts(t);
2014 }
2015 }
2016 #else
2017 {
2018 for(uword i=0; i<X_n_cols; ++i)
2019 {
2020 const eT* X_colptr = X.colptr(i);
2021
2022 eT min_dist = Datum<eT>::inf;
2023 uword best_g = 0;
2024
2025 for(uword g=0; g<N_gaus; ++g)
2026 {
2027 const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, means.colptr(g), mah_aux_mem);
2028
2029 if(dist < min_dist) { min_dist = dist; best_g = g; }
2030 }
2031
2032 eT* acc_mean = acc_means.colptr(best_g);
2033 eT* acc_dcov = acc_dcovs.colptr(best_g);
2034
2035 for(uword d=0; d<N_dims; ++d)
2036 {
2037 const eT x_d = X_colptr[d];
2038
2039 acc_mean[d] += x_d;
2040 acc_dcov[d] += x_d*x_d;
2041 }
2042
2043 acc_hefts_mem[best_g]++;
2044 }
2045 }
2046 #endif
2047
2048 eT* hefts_mem = access::rw(hefts).memptr();
2049
2050 for(uword g=0; g<N_gaus; ++g)
2051 {
2052 const eT* acc_mean = acc_means.colptr(g);
2053 const eT* acc_dcov = acc_dcovs.colptr(g);
2054 const uword acc_heft = acc_hefts_mem[g];
2055
2056 eT* mean = access::rw(means).colptr(g);
2057
2058 Mat<eT>& fcov = access::rw(fcovs).slice(g);
2059 fcov.zeros();
2060
2061 for(uword d=0; d<N_dims; ++d)
2062 {
2063 const eT tmp = acc_mean[d] / eT(acc_heft);
2064
2065 mean[d] = (acc_heft >= 1) ? tmp : eT(0);
2066 fcov.at(d,d) = (acc_heft >= 2) ? eT((acc_dcov[d] / eT(acc_heft)) - (tmp*tmp)) : eT(var_floor);
2067 }
2068
2069 hefts_mem[g] = eT(acc_heft) / eT(X_n_cols);
2070 }
2071
2072 em_fix_params(var_floor);
2073 }
2074
2075
2076
//! multi-threaded implementation of k-means, inspired by MapReduce;
//! iteratively refines the gaussian means on data X for up to max_iter
//! iterations (or until the mean movement drops below machine epsilon);
//! returns false if the means become degenerate or non-finite
template<typename eT>
template<uword dist_id>
inline
bool
gmm_full<eT>::km_iterate(const Mat<eT>& X, const uword max_iter, const bool verbose)
  {
  arma_extra_debug_sigprint();
  
  if(verbose)
    {
    get_cout_stream().unsetf(ios::showbase);
    get_cout_stream().unsetf(ios::uppercase);
    get_cout_stream().unsetf(ios::showpos);
    get_cout_stream().unsetf(ios::scientific);
    
    get_cout_stream().setf(ios::right);
    get_cout_stream().setf(ios::fixed);
    }
  
  const uword X_n_cols = X.n_cols;
  
  if(X_n_cols == 0)  { return true; }
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  const eT* mah_aux_mem = mah_aux.memptr();
  
  // acc_means: per-cluster sum of assigned samples
  // acc_hefts: per-cluster count of assigned samples
  // last_indx: index of the last sample assigned to each cluster
  //            (used below when resurrecting dead means)
  Mat<eT>    acc_means(N_dims, N_gaus, arma_zeros_indicator());
  Row<uword> acc_hefts(        N_gaus, arma_zeros_indicator());
  Row<uword> last_indx(        N_gaus, arma_zeros_indicator());
  
  Mat<eT> new_means = means;
  Mat<eT> old_means = means;
  
  running_mean_scalar<eT> rs_delta;
  
  #if defined(ARMA_USE_OPENMP)
    const umat boundaries = internal_gen_boundaries(X_n_cols);
    const uword n_threads = boundaries.n_cols;
    
    // per-thread accumulators; reduced after each assignment pass
    field< Mat<eT>    > t_acc_means(n_threads);
    field< Row<uword> > t_acc_hefts(n_threads);
    field< Row<uword> > t_last_indx(n_threads);
  #else
    const uword n_threads = 1;
  #endif
  
  if(verbose)  { get_cout_stream() << "gmm_full::learn(): k-means: n_threads: " << n_threads << '\n'; get_cout_stream().flush(); }
  
  for(uword iter=1; iter <= max_iter; ++iter)
    {
    #if defined(ARMA_USE_OPENMP)
      {
      for(uword t=0; t < n_threads; ++t)
        {
        t_acc_means(t).zeros(N_dims, N_gaus);
        t_acc_hefts(t).zeros(N_gaus);
        t_last_indx(t).zeros(N_gaus);
        }
      
      // assignment pass: each thread assigns its chunk of samples
      // to the nearest mean and accumulates sums/counts
      #pragma omp parallel for schedule(static)
      for(uword t=0; t < n_threads; ++t)
        {
        Mat<eT>& t_acc_means_t   = t_acc_means(t);
        uword*   t_acc_hefts_mem = t_acc_hefts(t).memptr();
        uword*   t_last_indx_mem = t_last_indx(t).memptr();
        
        const uword start_index = boundaries.at(0,t);
        const uword end_index   = boundaries.at(1,t);  // inclusive
        
        for(uword i=start_index; i <= end_index; ++i)
          {
          const eT* X_colptr = X.colptr(i);
          
          eT    min_dist = Datum<eT>::inf;
          uword best_g   = 0;
          
          for(uword g=0; g<N_gaus; ++g)
            {
            const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
            
            if(dist < min_dist)  { min_dist = dist; best_g = g; }
            }
          
          eT* t_acc_mean = t_acc_means_t.colptr(best_g);
          
          for(uword d=0; d<N_dims; ++d)  { t_acc_mean[d] += X_colptr[d]; }
          
          t_acc_hefts_mem[best_g]++;
          t_last_indx_mem[best_g] = i;
          }
        }
      
      // reduction
      
      acc_means = t_acc_means(0);
      acc_hefts = t_acc_hefts(0);
      
      for(uword t=1; t < n_threads; ++t)
        {
        acc_means += t_acc_means(t);
        acc_hefts += t_acc_hefts(t);
        }
      
      // take the last sample index from the highest thread that assigned
      // at least one sample to each cluster
      for(uword g=0; g < N_gaus; ++g)
      for(uword t=0; t < n_threads; ++t)
        {
        if( t_acc_hefts(t)(g) >= 1 )  { last_indx(g) = t_last_indx(t)(g); }
        }
      }
    #else
      {
      uword* acc_hefts_mem = acc_hefts.memptr();
      uword* last_indx_mem = last_indx.memptr();
      
      // assignment pass: assign each sample to the nearest mean
      // and accumulate sums/counts
      for(uword i=0; i < X_n_cols; ++i)
        {
        const eT* X_colptr = X.colptr(i);
        
        eT    min_dist = Datum<eT>::inf;
        uword best_g   = 0;
        
        for(uword g=0; g<N_gaus; ++g)
          {
          const eT dist = distance<eT,dist_id>::eval(N_dims, X_colptr, old_means.colptr(g), mah_aux_mem);
          
          if(dist < min_dist)  { min_dist = dist; best_g = g; }
          }
        
        eT* acc_mean = acc_means.colptr(best_g);
        
        for(uword d=0; d<N_dims; ++d)  { acc_mean[d] += X_colptr[d]; }
        
        acc_hefts_mem[best_g]++;
        last_indx_mem[best_g] = i;
        }
      }
    #endif
    
    // generate new means
    
    uword* acc_hefts_mem = acc_hefts.memptr();
    
    for(uword g=0; g < N_gaus; ++g)
      {
      const eT*   acc_mean = acc_means.colptr(g);
      const uword acc_heft = acc_hefts_mem[g];
      
      eT* new_mean = access::rw(new_means).colptr(g);
      
      for(uword d=0; d<N_dims; ++d)
        {
        new_mean[d] = (acc_heft >= 1) ? (acc_mean[d] / eT(acc_heft)) : eT(0);
        }
      }
    
    
    // heuristics to resurrect dead means (clusters which attracted no samples)
    
    const uvec dead_gs = find(acc_hefts == uword(0));
    
    if(dead_gs.n_elem > 0)
      {
      if(verbose)  { get_cout_stream() << "gmm_full::learn(): k-means: recovering from dead means\n"; get_cout_stream().flush(); }
      
      uword* last_indx_mem = last_indx.memptr();
      
      // live clusters with at least 2 samples can afford to donate a sample;
      // consider the largest cluster indices first
      const uvec live_gs = sort( find(acc_hefts >= uword(2)), "descend" );
      
      if(live_gs.n_elem == 0)  { return false; }
      
      uword live_gs_count = 0;
      
      for(uword dead_gs_count = 0; dead_gs_count < dead_gs.n_elem; ++dead_gs_count)
        {
        const uword dead_g_id = dead_gs(dead_gs_count);
        
        uword proposed_i = 0;
        
        if(live_gs_count < live_gs.n_elem)
          {
          const uword live_g_id = live_gs(live_gs_count); ++live_gs_count;
          
          if(live_g_id == dead_g_id)  { return false; }
          
          // recover by using a sample from a known good mean
          proposed_i = last_indx_mem[live_g_id];
          }
        else
          {
          // recover by using a randomly selected sample (last resort)
          proposed_i = as_scalar(randi<uvec>(1, distr_param(0,X_n_cols-1)));
          }
        
        if(proposed_i >= X_n_cols)  { return false; }
        
        new_means.col(dead_g_id) = X.col(proposed_i);
        }
      }
    
    // convergence check: average distance between old and new means
    rs_delta.reset();
    
    for(uword g=0; g < N_gaus; ++g)
      {
      rs_delta( distance<eT,dist_id>::eval(N_dims, old_means.colptr(g), new_means.colptr(g), mah_aux_mem) );
      }
    
    if(verbose)
      {
      get_cout_stream() << "gmm_full::learn(): k-means: iteration: ";
      get_cout_stream().unsetf(ios::scientific);
      get_cout_stream().setf(ios::fixed);
      get_cout_stream().width(std::streamsize(4));
      get_cout_stream() << iter;
      get_cout_stream() << " delta: ";
      get_cout_stream().unsetf(ios::fixed);
      //get_cout_stream().setf(ios::scientific);
      get_cout_stream() << rs_delta.mean() << '\n';
      get_cout_stream().flush();
      }
    
    arma::swap(old_means, new_means);
    
    if(rs_delta.mean() <= Datum<eT>::eps)  { break; }
    }
  
  // after the swap, old_means holds the most recent means
  access::rw(means) = old_means;
  
  if(means.is_finite() == false)  { return false; }
  
  return true;
  }
2311
2312
2313
//! multi-threaded implementation of Expectation-Maximisation, inspired by MapReduce;
//! refines means, full covariances and hefts on data X for up to max_iter
//! iterations, stopping early when the average log-likelihood plateaus;
//! returns false if the model becomes degenerate or non-finite
template<typename eT>
inline
bool
gmm_full<eT>::em_iterate(const Mat<eT>& X, const uword max_iter, const eT var_floor, const bool verbose)
  {
  arma_extra_debug_sigprint();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  if(verbose)
    {
    get_cout_stream().unsetf(ios::showbase);
    get_cout_stream().unsetf(ios::uppercase);
    get_cout_stream().unsetf(ios::showpos);
    get_cout_stream().unsetf(ios::scientific);
    
    get_cout_stream().setf(ios::right);
    get_cout_stream().setf(ios::fixed);
    }
  
  // partition the samples into contiguous chunks, one chunk per thread
  const umat boundaries = internal_gen_boundaries(X.n_cols);
  
  const uword n_threads = boundaries.n_cols;
  
  // per-thread accumulators, filled by em_update_params() in each iteration
  field< Mat<eT>  > t_acc_means(n_threads);
  field< Cube<eT> > t_acc_fcovs(n_threads);
  
  field< Col<eT> > t_acc_norm_lhoods(n_threads);
  field< Col<eT> > t_gaus_log_lhoods(n_threads);
  
  Col<eT> t_progress_log_lhood(n_threads, arma_nozeros_indicator());
  
  for(uword t=0; t<n_threads; t++)
    {
    t_acc_means[t].set_size(N_dims, N_gaus);
    t_acc_fcovs[t].set_size(N_dims, N_dims, N_gaus);
    
    t_acc_norm_lhoods[t].set_size(N_gaus);
    t_gaus_log_lhoods[t].set_size(N_gaus);
    }
  
  
  if(verbose)
    {
    get_cout_stream() << "gmm_full::learn(): EM: n_threads: " << n_threads << '\n';
    }
  
  eT old_avg_log_p = -Datum<eT>::inf;
  
  // NOTE(review): Cholesky factors are presumably not required while iterating,
  // so their computation is skipped inside init_constants() — confirm against
  // the init_constants() implementation
  const bool calc_chol = false;
  
  for(uword iter=1; iter <= max_iter; ++iter)
    {
    init_constants(calc_chol);
    
    em_update_params(X, boundaries, t_acc_means, t_acc_fcovs, t_acc_norm_lhoods, t_gaus_log_lhoods, t_progress_log_lhood, var_floor);
    
    em_fix_params(var_floor);
    
    // average log-likelihood over the threads' per-chunk progress values
    const eT new_avg_log_p = accu(t_progress_log_lhood) / eT(t_progress_log_lhood.n_elem);
    
    if(verbose)
      {
      get_cout_stream() << "gmm_full::learn(): EM: iteration: ";
      get_cout_stream().unsetf(ios::scientific);
      get_cout_stream().setf(ios::fixed);
      get_cout_stream().width(std::streamsize(4));
      get_cout_stream() << iter;
      get_cout_stream() << " avg_log_p: ";
      get_cout_stream().unsetf(ios::fixed);
      //get_cout_stream().setf(ios::scientific);
      get_cout_stream() << new_avg_log_p << '\n';
      get_cout_stream().flush();
      }
    
    if(arma_isfinite(new_avg_log_p) == false)  { return false; }
    
    // stop early once the log-likelihood has effectively stopped changing
    if(std::abs(old_avg_log_p - new_avg_log_p) <= Datum<eT>::eps)  { break; }
    
    
    old_avg_log_p = new_avg_log_p;
    }
  
  
  // sanity check: every covariance matrix must have strictly positive diagonal
  for(uword g=0; g < N_gaus; ++g)
    {
    const Mat<eT>& fcov = fcovs.slice(g);
    
    if(any(vectorise(fcov.diag()) <= eT(0)))  { return false; }
    }
  
  if(means.is_finite() == false)  { return false; }
  if(fcovs.is_finite() == false)  { return false; }
  if(hefts.is_finite() == false)  { return false; }
  
  return true;
  }
2413
2414
2415
2416
//! Perform the parameter-update half of one EM iteration.
//! "Map": each thread accumulates responsibility-weighted statistics over its
//! slice of X (column range given by boundaries) via em_generate_acc().
//! "Reduce": partial accumulators are summed into the thread-0 accumulators.
//! Each Gaussian is then conditionally updated: the new heft/mean/covariance
//! are committed only if the candidate covariance is finite and symmetric
//! positive definite; otherwise the component keeps its previous parameters.
//! var_floor: lower limit imposed on the diagonal of each candidate covariance.
template<typename eT>
inline
void
gmm_full<eT>::em_update_params
  (
  const Mat<eT>&          X,
  const umat&             boundaries,
  field<  Mat<eT> >&      t_acc_means,
  field< Cube<eT> >&      t_acc_fcovs,
  field<  Col<eT> >&      t_acc_norm_lhoods,
  field<  Col<eT> >&      t_gaus_log_lhoods,
  Col<eT>&                t_progress_log_lhood,
  const eT                var_floor
  )
  {
  arma_extra_debug_sigprint();
  
  const uword n_threads = boundaries.n_cols;
  
  
  // em_generate_acc() is the "map" operation, which produces partial accumulators for means, diagonal covariances and hefts
  
  #if defined(ARMA_USE_OPENMP)
    {
    #pragma omp parallel for schedule(static)
    for(uword t=0; t<n_threads; t++)
      {
      // each thread works on its own accumulator set; no sharing, so no locking needed
      Mat<eT>& acc_means          = t_acc_means[t];
      Cube<eT>& acc_fcovs         = t_acc_fcovs[t];
      Col<eT>& acc_norm_lhoods    = t_acc_norm_lhoods[t];
      Col<eT>& gaus_log_lhoods    = t_gaus_log_lhoods[t];
      eT& progress_log_lhood      = t_progress_log_lhood[t];
      
      // boundaries.at(0,t) and boundaries.at(1,t) are the first and last column index for thread t
      em_generate_acc(X, boundaries.at(0,t), boundaries.at(1,t), acc_means, acc_fcovs, acc_norm_lhoods, gaus_log_lhoods, progress_log_lhood);
      }
    }
  #else
    {
    // single-threaded build: only accumulator set 0 is used
    em_generate_acc(X, boundaries.at(0,0), boundaries.at(1,0), t_acc_means[0], t_acc_fcovs[0], t_acc_norm_lhoods[0], t_gaus_log_lhoods[0], t_progress_log_lhood[0]);
    }
  #endif
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  Mat<eT>& final_acc_means = t_acc_means[0];
  Cube<eT>& final_acc_fcovs = t_acc_fcovs[0];
  
  Col<eT>& final_acc_norm_lhoods = t_acc_norm_lhoods[0];
  
  
  // the "reduce" operation, which combines the partial accumulators produced by the separate threads
  
  for(uword t=1; t<n_threads; t++)
    {
    final_acc_means += t_acc_means[t];
    final_acc_fcovs += t_acc_fcovs[t];
    
    final_acc_norm_lhoods += t_acc_norm_lhoods[t];
    }
  
  
  eT* hefts_mem = access::rw(hefts).memptr();
  
  // scratch matrix, reused across components (also reused as junk output for inv_sympd below)
  Mat<eT> mean_outer(N_dims, N_dims, arma_nozeros_indicator());
  
  
  //// update each component without sanity checking
  //for(uword g=0; g < N_gaus; ++g)
  //  {
  //  const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
  //  
  //  hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
  //  
  //  eT*     mean_mem = access::rw(means).colptr(g);
  //  eT* acc_mean_mem = final_acc_means.colptr(g);
  //  
  //  for(uword d=0; d < N_dims; ++d)
  //    {
  //    mean_mem[d] = acc_mean_mem[d] / acc_norm_lhood;
  //    }
  //  
  //  const Col<eT> mean(mean_mem, N_dims, false, true);
  //  
  //  mean_outer = mean * mean.t();
  //  
  //  Mat<eT>&     fcov = access::rw(fcovs).slice(g);
  //  Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
  //  
  //  fcov = acc_fcov / acc_norm_lhood - mean_outer;
  //  }
  
  
  // conditionally update each component; if only a subset of the hefts was updated, em_fix_params() will sanitise them
  for(uword g=0; g < N_gaus; ++g)
    {
    // total responsibility mass for component g; clamped away from zero to avoid division by zero
    const eT acc_norm_lhood = (std::max)( final_acc_norm_lhoods[g], std::numeric_limits<eT>::min() );
    
    if(arma_isfinite(acc_norm_lhood) == false)  { continue; }
    
    eT* acc_mean_mem = final_acc_means.colptr(g);
    
    // normalise the accumulated weighted sum in place, turning it into the candidate mean
    for(uword d=0; d < N_dims; ++d)
      {
      acc_mean_mem[d] /= acc_norm_lhood;
      }
    
    // non-owning view over the candidate mean (aliases acc_mean_mem)
    const Col<eT> new_mean(acc_mean_mem, N_dims, false, true);
    
    mean_outer = new_mean * new_mean.t();
    
    // candidate covariance: E[x x'] - mean * mean'
    Mat<eT>& acc_fcov = final_acc_fcovs.slice(g);
    
    acc_fcov /= acc_norm_lhood;
    acc_fcov -= mean_outer;
    
    // impose the variance floor on the diagonal
    for(uword d=0; d < N_dims; ++d)
      {
      eT& val = acc_fcov.at(d,d);
      
      if(val < var_floor)  { val = var_floor; }
      }
    
    if(acc_fcov.is_finite() == false)  { continue; }
    
    // sanity check: the candidate covariance must have a positive determinant
    // and must be invertible as a symmetric positive definite matrix
    eT log_det_val  = eT(0);
    eT log_det_sign = eT(0);
    
    const bool log_det_status = log_det(log_det_val, log_det_sign, acc_fcov);
    
    const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) );
    
    const bool inv_ok = (log_det_ok) ? bool(auxlib::inv_sympd(mean_outer, acc_fcov)) : bool(false);  // mean_outer is used as a junk matrix
    
    // commit the new heft, mean and covariance only if all checks passed
    if(log_det_ok && inv_ok)
      {
      hefts_mem[g] = acc_norm_lhood / eT(X.n_cols);
      
      eT* mean_mem = access::rw(means).colptr(g);
      
      for(uword d=0; d < N_dims; ++d)
        {
        mean_mem[d] = acc_mean_mem[d];
        }
      
      Mat<eT>& fcov = access::rw(fcovs).slice(g);
      
      fcov = acc_fcov;
      }
    }
  }
2568
2569
2570
//! The "map" stage of EM: accumulate, over columns start_index to end_index
//! (both inclusive) of X, the responsibility-weighted statistics required by
//! em_update_params():
//!   acc_norm_lhoods[g]   += r(g,i)
//!   acc_means.col(g)     += r(g,i) * x_i
//!   acc_fcovs.slice(g)   += r(g,i) * x_i * x_i'
//! where r(g,i) is the posterior responsibility of Gaussian g for sample x_i.
//! progress_log_lhood is set to the average per-sample log-likelihood
//! over the processed range.
template<typename eT>
inline
void
gmm_full<eT>::em_generate_acc
  (
  const Mat<eT>& X,
  const uword    start_index,
  const uword      end_index,
  Mat<eT>&       acc_means,
  Cube<eT>&      acc_fcovs,
  Col<eT>&       acc_norm_lhoods,
  Col<eT>&       gaus_log_lhoods,
  eT&            progress_log_lhood
  )
  const
  {
  arma_extra_debug_sigprint();
  
  progress_log_lhood = eT(0);
  
  // reset all accumulators before summing over the assigned range
  acc_means.zeros();
  acc_fcovs.zeros();
  
  acc_norm_lhoods.zeros();
  gaus_log_lhoods.zeros();
  
  const uword N_dims = means.n_rows;
  const uword N_gaus = means.n_cols;
  
  const eT* log_hefts_mem       = log_hefts.memptr();
        eT* gaus_log_lhoods_mem = gaus_log_lhoods.memptr();
  
  
  for(uword i=start_index; i <= end_index; i++)
    {
    const eT* x = X.colptr(i);
    
    // joint log-likelihood of sample i under each component: log N(x|g) + log heft(g)
    for(uword g=0; g < N_gaus; ++g)
      {
      gaus_log_lhoods_mem[g] = internal_scalar_log_p(x, g) + log_hefts_mem[g];
      }
    
    // total log-likelihood of the sample, computed via log-sum-exp for numerical stability
    eT log_lhood_sum = gaus_log_lhoods_mem[0];
    
    for(uword g=1; g < N_gaus; ++g)
      {
      log_lhood_sum = log_add_exp(log_lhood_sum, gaus_log_lhoods_mem[g]);
      }
    
    progress_log_lhood += log_lhood_sum;
    
    for(uword g=0; g < N_gaus; ++g)
      {
      // posterior responsibility of component g for sample i
      const eT norm_lhood = std::exp(gaus_log_lhoods_mem[g] - log_lhood_sum);
      
      acc_norm_lhoods[g] += norm_lhood;
      
      eT* acc_mean_mem = acc_means.colptr(g);
      
      for(uword d=0; d < N_dims; ++d)
        {
        acc_mean_mem[d] += x[d] * norm_lhood;
        }
      
      Mat<eT>& acc_fcov = access::rw(acc_fcovs).slice(g);
      
      // specialised version of acc_fcov += norm_lhood * (xx * xx.t());
      // only the upper triangle (per column d) is computed; each off-diagonal
      // value is mirrored to the transposed position via a second pointer that
      // walks along row d with stride N_dims
      
      for(uword d=0; d < N_dims; ++d)
        {
        const uword dp1 = d+1;
        
        const eT xd = x[d];
        
        eT* acc_fcov_col_d = acc_fcov.colptr(d) + d;
        // NOTE(review): when d == N_dims-1, at(d,dp1) forms a pointer beyond the
        // slice; it is never dereferenced (the e-loop below is empty in that case),
        // but strictly this is out-of-bounds pointer formation — confirm acceptable
        eT* acc_fcov_row_d = &(acc_fcov.at(d,dp1));
        
        // diagonal element: w * x[d]^2
        (*acc_fcov_col_d) += norm_lhood * (xd * xd);  acc_fcov_col_d++;
        
        for(uword e=dp1; e < N_dims; ++e)
          {
          const eT val = norm_lhood * (xd * x[e]);
          
          (*acc_fcov_col_d) += val;  acc_fcov_col_d++;
          (*acc_fcov_row_d) += val;  acc_fcov_row_d += N_dims;
          }
        }
      }
    }
  
  // average over the number of processed samples (range is inclusive at both ends)
  progress_log_lhood /= eT((end_index - start_index) + 1);
  }
2663
2664
2665
2666 template<typename eT>
2667 inline
2668 void
em_fix_params(const eT var_floor)2669 gmm_full<eT>::em_fix_params(const eT var_floor)
2670 {
2671 arma_extra_debug_sigprint();
2672
2673 const uword N_dims = means.n_rows;
2674 const uword N_gaus = means.n_cols;
2675
2676 const eT var_ceiling = std::numeric_limits<eT>::max();
2677
2678 for(uword g=0; g < N_gaus; ++g)
2679 {
2680 Mat<eT>& fcov = access::rw(fcovs).slice(g);
2681
2682 for(uword d=0; d < N_dims; ++d)
2683 {
2684 eT& var_val = fcov.at(d,d);
2685
2686 if(var_val < var_floor ) { var_val = var_floor; }
2687 else if(var_val > var_ceiling) { var_val = var_ceiling; }
2688 else if(arma_isnan(var_val) ) { var_val = eT(1); }
2689 }
2690 }
2691
2692
2693 eT* hefts_mem = access::rw(hefts).memptr();
2694
2695 for(uword g1=0; g1 < N_gaus; ++g1)
2696 {
2697 if(hefts_mem[g1] > eT(0))
2698 {
2699 const eT* means_colptr_g1 = means.colptr(g1);
2700
2701 for(uword g2=(g1+1); g2 < N_gaus; ++g2)
2702 {
2703 if( (hefts_mem[g2] > eT(0)) && (std::abs(hefts_mem[g1] - hefts_mem[g2]) <= std::numeric_limits<eT>::epsilon()) )
2704 {
2705 const eT dist = distance<eT,1>::eval(N_dims, means_colptr_g1, means.colptr(g2), means_colptr_g1);
2706
2707 if(dist == eT(0)) { hefts_mem[g2] = eT(0); }
2708 }
2709 }
2710 }
2711 }
2712
2713 const eT heft_floor = std::numeric_limits<eT>::min();
2714 const eT heft_initial = eT(1) / eT(N_gaus);
2715
2716 for(uword i=0; i < N_gaus; ++i)
2717 {
2718 eT& heft_val = hefts_mem[i];
2719
2720 if(heft_val < heft_floor) { heft_val = heft_floor; }
2721 else if(heft_val > eT(1) ) { heft_val = eT(1); }
2722 else if(arma_isnan(heft_val) ) { heft_val = heft_initial; }
2723 }
2724
2725 const eT heft_sum = accu(hefts);
2726
2727 if((heft_sum < (eT(1) - Datum<eT>::eps)) || (heft_sum > (eT(1) + Datum<eT>::eps))) { access::rw(hefts) /= heft_sum; }
2728 }
2729
2730
2731
2732 } // namespace gmm_priv
2733
2734
2735 //! @}
2736