package Algorithm::NaiveBayes::Model::Gaussian;
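
# Gaussian ("normal distribution") model for Algorithm::NaiveBayes: each
# real-valued attribute is summarized per label by its mean and variance,
# and prediction scores a label by its log prior plus the Gaussian
# log-density of the new attribute values.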

use strict;
use base qw(Algorithm::NaiveBayes);
use Algorithm::NaiveBayes::Util qw(sum variance rescale);
use constant Pi => 4*atan2(1, 1);

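# Accumulate one training instance: increment the label's instance count
# and append each attribute's numeric value to that label's per-attribute
# list, so do_train() can compute means and variances later.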
sub do_add_instance {
  my ($self, $attributes, $labels, $training_data) = @_;

  foreach my $label ( @$labels ) {
    my $mylabel = $training_data->{labels}{$label} ||= {};
    $mylabel->{count}++;
    while (my ($attr, $value) = each %$attributes) {
      push @{$mylabel->{attrs}{$attr}}, $value;
    }
  }
}

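# Build the model: a log prior for each label (count / total instances),
# plus the mean and variance of the observed values for every
# label-attribute pair.  Attributes with zero variance are skipped,
# since do_predict() takes log($var) and divides by $var.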
sub do_train {
  my ($self, $training_data) = @_;
  my $m = {};

  my $instances = $self->instances;
  my $labels = $training_data->{labels};

  while (my ($label, $data) = each %$labels) {
    $m->{prior_probs}{$label} = log($labels->{$label}{count} / $instances);

    # Calculate the mean & variance for each label-attribute combination
    while (my ($attr, $values) = each %{$data->{attrs}}) {
      my $mean = sum($values) / @$values;
      my $var  = variance($values, $mean)
        or next;  # Can't use a variance of zero
      @{ $m->{summary}{$attr}{$label} }{'mean', 'var'} = ($mean, $var);
    }
  }
  return $m;
}

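# Score a new instance: start each label at its log prior, then add the
# Gaussian log-density of every known attribute value under that label's
# stored mean and variance.  The scores are passed through rescale()
# from Algorithm::NaiveBayes::Util before being returned.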
sub do_predict {
  my ($self, $m, $newattrs) = @_;

  my %scores = %{$m->{prior_probs}};
  while (my ($feature, $value) = each %$newattrs) {
    next unless exists $m->{summary}{$feature};  # Ignore totally unseen features
    while (my ($label, $data) = each %{$m->{summary}{$feature}}) {
      my ($mean, $var) = @{$data}{'mean', 'var'};
      # This is simplified from
      #   +=  log( 1/sqrt($var*2*Pi) * exp(-($value-$mean)**2/(2*$var)) );
      $scores{$label} -= 0.5*(log($var) + log(2*Pi) + ($value-$mean)**2/$var);
    }
  }

  rescale(\%scores);

  return \%scores;
}

1;
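
# A minimal usage sketch (illustrative only: it assumes the public
# Algorithm::NaiveBayes interface of new(), add_instance(), train() and
# predict(), and that the model_type constructor argument selects this
# subclass; the attribute names, values and labels are invented):
#
#   use Algorithm::NaiveBayes;
#
#   my $nb = Algorithm::NaiveBayes->new(model_type => 'Gaussian');
#
#   # Attribute values are treated as real-valued measurements
#   $nb->add_instance(attributes => { height => 72, weight => 180 },
#                     label      => 'tall');
#   $nb->add_instance(attributes => { height => 60, weight => 120 },
#                     label      => 'short');
#
#   $nb->train;
#
#   # Hash reference mapping each label to its rescaled score
#   my $scores = $nb->predict(attributes => { height => 68, weight => 150 });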