1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "chromecast/base/statistics/weighted_moving_linear_regression.h"
6 
7 #include <math.h>
8 #include <algorithm>
9 
10 #include "base/check_op.h"
11 #include "base/logging.h"
12 
13 namespace chromecast {
14 
WeightedMovingLinearRegression(int64_t max_x_range)15 WeightedMovingLinearRegression::WeightedMovingLinearRegression(
16     int64_t max_x_range)
17     : max_x_range_(max_x_range),
18       covariance_(0),
19       slope_(0),
20       slope_variance_(0),
21       intercept_variance_(0),
22       has_estimate_(false) {
23   DCHECK_GE(max_x_range_, 0);
24 }
25 
~WeightedMovingLinearRegression()26 WeightedMovingLinearRegression::~WeightedMovingLinearRegression() {}
27 
AddSample(int64_t x,int64_t y,double weight)28 void WeightedMovingLinearRegression::AddSample(int64_t x,
29                                                int64_t y,
30                                                double weight) {
31   DCHECK_GE(weight, 0);
32   if (!samples_.empty())
33     DCHECK_GE(x, samples_.back().x);
34 
35   UpdateSet(x, y, weight);
36   Sample sample = {x, y, weight};
37   samples_.push_back(sample);
38 
39   // Remove old samples.
40   while (x - samples_.front().x > max_x_range_) {
41     const Sample& old_sample = samples_.front();
42     UpdateSet(old_sample.x, old_sample.y, -old_sample.weight);
43     samples_.pop_front();
44   }
45   DCHECK(!samples_.empty());
46 
47   if (samples_.size() <= 2 || x_mean_.sum_weights() == 0 ||
48       x_mean_.variance_sum() == 0) {
49     has_estimate_ = false;
50     return;
51   }
52 
53   slope_ = covariance_ / x_mean_.variance_sum();
54 
55   double residual_sum_squares =
56       (covariance_ * covariance_) / x_mean_.variance_sum();
57   double mean_squared_error =
58       (y_mean_.variance_sum() - residual_sum_squares) / (samples_.size() - 2);
59 
60   slope_variance_ = std::max(0.0, mean_squared_error / x_mean_.variance_sum());
61   intercept_variance_ = std::max(
62       0.0, (slope_variance_ * x_mean_.variance_sum()) / x_mean_.sum_weights());
63 
64   has_estimate_ = true;
65 }
66 
EstimateY(int64_t x,int64_t * y,double * error) const67 bool WeightedMovingLinearRegression::EstimateY(int64_t x,
68                                                int64_t* y,
69                                                double* error) const {
70   if (!has_estimate_)
71     return false;
72 
73   double x_diff = x - x_mean_.weighted_mean();
74   double y_estimate = y_mean_.weighted_mean() + (slope_ * x_diff);
75 
76   *y = static_cast<int64_t>(round(y_estimate));
77   *error = sqrt(intercept_variance_ + (slope_variance_ * x_diff * x_diff));
78   return true;
79 }
80 
EstimateSlope(double * slope,double * error) const81 bool WeightedMovingLinearRegression::EstimateSlope(double* slope,
82                                                    double* error) const {
83   if (!has_estimate_)
84     return false;
85 
86   *slope = slope_;
87   *error = sqrt(slope_variance_);
88   return true;
89 }
90 
UpdateSet(int64_t x,int64_t y,double weight)91 void WeightedMovingLinearRegression::UpdateSet(int64_t x,
92                                                int64_t y,
93                                                double weight) {
94   double old_y_mean = y_mean_.weighted_mean();
95   x_mean_.AddSample(x, weight);
96   y_mean_.AddSample(y, weight);
97   covariance_ += weight * (x - x_mean_.weighted_mean()) * (y - old_y_mean);
98 }
99 
DumpSamples() const100 void WeightedMovingLinearRegression::DumpSamples() const {
101   for (auto sample : samples_) {
102     LOG(INFO) << "x, y, weight: " << sample.x << " " << sample.y << " "
103               << sample.weight;
104   }
105 }
106 
107 }  // namespace chromecast
108