1 // Copyright 2018 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "chromecast/base/statistics/weighted_moving_linear_regression.h"
6
7 #include <math.h>
8 #include <algorithm>
9
10 #include "base/check_op.h"
11 #include "base/logging.h"
12
13 namespace chromecast {
14
WeightedMovingLinearRegression(int64_t max_x_range)15 WeightedMovingLinearRegression::WeightedMovingLinearRegression(
16 int64_t max_x_range)
17 : max_x_range_(max_x_range),
18 covariance_(0),
19 slope_(0),
20 slope_variance_(0),
21 intercept_variance_(0),
22 has_estimate_(false) {
23 DCHECK_GE(max_x_range_, 0);
24 }
25
~WeightedMovingLinearRegression()26 WeightedMovingLinearRegression::~WeightedMovingLinearRegression() {}
27
AddSample(int64_t x,int64_t y,double weight)28 void WeightedMovingLinearRegression::AddSample(int64_t x,
29 int64_t y,
30 double weight) {
31 DCHECK_GE(weight, 0);
32 if (!samples_.empty())
33 DCHECK_GE(x, samples_.back().x);
34
35 UpdateSet(x, y, weight);
36 Sample sample = {x, y, weight};
37 samples_.push_back(sample);
38
39 // Remove old samples.
40 while (x - samples_.front().x > max_x_range_) {
41 const Sample& old_sample = samples_.front();
42 UpdateSet(old_sample.x, old_sample.y, -old_sample.weight);
43 samples_.pop_front();
44 }
45 DCHECK(!samples_.empty());
46
47 if (samples_.size() <= 2 || x_mean_.sum_weights() == 0 ||
48 x_mean_.variance_sum() == 0) {
49 has_estimate_ = false;
50 return;
51 }
52
53 slope_ = covariance_ / x_mean_.variance_sum();
54
55 double residual_sum_squares =
56 (covariance_ * covariance_) / x_mean_.variance_sum();
57 double mean_squared_error =
58 (y_mean_.variance_sum() - residual_sum_squares) / (samples_.size() - 2);
59
60 slope_variance_ = std::max(0.0, mean_squared_error / x_mean_.variance_sum());
61 intercept_variance_ = std::max(
62 0.0, (slope_variance_ * x_mean_.variance_sum()) / x_mean_.sum_weights());
63
64 has_estimate_ = true;
65 }
66
EstimateY(int64_t x,int64_t * y,double * error) const67 bool WeightedMovingLinearRegression::EstimateY(int64_t x,
68 int64_t* y,
69 double* error) const {
70 if (!has_estimate_)
71 return false;
72
73 double x_diff = x - x_mean_.weighted_mean();
74 double y_estimate = y_mean_.weighted_mean() + (slope_ * x_diff);
75
76 *y = static_cast<int64_t>(round(y_estimate));
77 *error = sqrt(intercept_variance_ + (slope_variance_ * x_diff * x_diff));
78 return true;
79 }
80
EstimateSlope(double * slope,double * error) const81 bool WeightedMovingLinearRegression::EstimateSlope(double* slope,
82 double* error) const {
83 if (!has_estimate_)
84 return false;
85
86 *slope = slope_;
87 *error = sqrt(slope_variance_);
88 return true;
89 }
90
UpdateSet(int64_t x,int64_t y,double weight)91 void WeightedMovingLinearRegression::UpdateSet(int64_t x,
92 int64_t y,
93 double weight) {
94 double old_y_mean = y_mean_.weighted_mean();
95 x_mean_.AddSample(x, weight);
96 y_mean_.AddSample(y, weight);
97 covariance_ += weight * (x - x_mean_.weighted_mean()) * (y - old_y_mean);
98 }
99
DumpSamples() const100 void WeightedMovingLinearRegression::DumpSamples() const {
101 for (auto sample : samples_) {
102 LOG(INFO) << "x, y, weight: " << sample.x << " " << sample.y << " "
103 << sample.weight;
104 }
105 }
106
107 } // namespace chromecast
108