1 #pragma once
2 #ifndef CATA_TESTS_TEST_STATISTICS_H
3 #define CATA_TESTS_TEST_STATISTICS_H
4 
5 #include <algorithm>
6 #include <cmath>
7 #include <iosfwd>
8 #include <limits>
9 #include <type_traits>
10 #include <vector>
11 
12 #include "catch/catch.hpp"
13 
// Z-values (standard-normal quantiles) used to size a confidence interval.
// The suffix names the confidence level in percent, e.g. Z99_9 is the
// multiplier for a 99.9% interval; a larger Z gives a wider interval.
constexpr double Z95 = 1.96;
constexpr double Z99 = 2.576;
constexpr double Z99_9 = 3.291;
constexpr double Z99_99 = 3.891;
// NOTE(review): the last two look rounded up from the exact two-sided
// quantiles (roughly 4.417 and 4.892), which errs on the conservative side.
constexpr double Z99_999 = 4.5;
constexpr double Z99_999_9 = 5.0;
21 
// Symmetric target range expressed as midpoint +/- epsilon.  Writing it this
// way makes it easier to see how wide the range actually is than spelling out
// the upper and lower bounds separately.
struct epsilon_threshold {
    // Centre of the accepted range.
    double midpoint;
    // Half-width of the accepted range; the range is
    // (midpoint - epsilon, midpoint + epsilon).
    double epsilon;
};
28 
// Target range given by explicit lower and upper bounds; use this when the
// accepted range is not symmetric around a midpoint.
struct upper_lower_threshold {
    // Lower edge of the accepted range.
    double lower_thresh;
    // Upper edge of the accepted range.
    double upper_thresh;
};
34 
// Sentinel stored in the cached margin of error to mean "not computed yet".
// The margin of error is cached, so adding a new value must reset the cache
// to this sentinel so that the margin gets calculated again on next use.
static constexpr double invalid_err = -1;
38 
39 template<typename T>
40 class statistics
41 {
42     private:
43         int _types;
44         int _n;
45         double _sum;
46         double _error;
47         const double _Z;
48         const double _Zsq;
49         T _max;
50         T _min;
51         std::vector< T > samples;
52     public:
53         explicit statistics( const double Z = Z99_9 ) :
54             _types( 0 ), _n( 0 ), _sum( 0 ), _error( invalid_err ),
55             _Z( Z ),  _Zsq( Z * Z ), _max( std::numeric_limits<T>::min() ),
56             _min( std::numeric_limits<T>::max() ) {}
57 
new_type()58         void new_type() {
59             _types++;
60         }
add(T new_val)61         void add( T new_val ) {
62             _error = invalid_err;
63             _n++;
64             _sum += new_val;
65             _max = std::max( _max, new_val );
66             _min = std::min( _min, new_val );
67             samples.push_back( new_val );
68         }
69 
70         // Adjusted Wald error is only valid for a discrete binary test. Note
71         // because error takes into account population, it is only valid to
72         // test against the upper/lower bound.
73         //
74         // The goal here is to get the most accurate statistics about the
75         // smallest sample size feasible.  The tests end up getting run many
76         // times over a short period, so any real issue may sometimes get a
77         // false positive, but over a series of runs will get shaken out in an
78         // obvious way.
79         //
80         // Outside of this class, this should only be used for debugging
81         // purposes.
82         template<typename U = T>
83         typename std::enable_if< std::is_same< U, bool >::value, double >::type
margin_of_error()84         margin_of_error() {
85             if( _error != invalid_err ) {
86                 return _error;
87             }
88             // Implementation of outline from https://measuringu.com/ci-five-steps/
89             const double adj_numerator = ( _Zsq / 2 ) + _sum;
90             const double adj_denominator = _Zsq + _n;
91             const double adj_proportion = adj_numerator / adj_denominator;
92             const double a = adj_proportion * ( 1.0 - adj_proportion );
93             const double b = a / adj_denominator;
94             const double c = std::sqrt( b );
95             _error = c * _Z;
96             return _error;
97         }
98         // Standard error is intended to be used with continuous data samples.
99         // We're using an approximation here so it is only appropriate to use
100         // the upper/lower bound to test for reasons similar to adjusted Wald
101         // error.
102         // Outside of this class, this should only be used for debugging purposes.
103         // https://measuringu.com/ci-five-steps/
104         template<typename U = T>
105         typename std::enable_if < ! std::is_same< U, bool >::value, double >::type
margin_of_error()106         margin_of_error() {
107             if( _error != invalid_err ) {
108                 return _error;
109             }
110             const double std_err = stddev() / std::sqrt( _n );
111             _error = std_err * _Z;
112             return _error;
113         }
114 
115         /** Use to continue testing until we are sure whether the result is
116          * inside or outside the target.
117          *
118          * Returns true if the confidence interval partially overlaps the target region.
119          */
uncertain_about(const epsilon_threshold & t)120         bool uncertain_about( const epsilon_threshold &t ) {
121             return !test_threshold( t ) && // Inside target
122                    t.midpoint - t.epsilon < upper() && // Below target
123                    t.midpoint + t.epsilon > lower(); // Above target
124         }
125 
test_threshold(const epsilon_threshold & t)126         bool test_threshold( const epsilon_threshold &t ) {
127             return ( ( t.midpoint - t.epsilon ) < lower() &&
128                      ( t.midpoint + t.epsilon ) > upper() );
129         }
test_threshold(const upper_lower_threshold & t)130         bool test_threshold( const upper_lower_threshold &t ) {
131             return ( t.lower_thresh < lower() && t.upper_thresh > upper() );
132         }
upper()133         double upper() {
134             double result = avg() + margin_of_error();
135             if( std::is_same<T, bool>::value ) {
136                 result = std::min( result, 1.0 );
137             }
138             return result;
139         }
lower()140         double lower() {
141             double result = avg() - margin_of_error();
142             if( std::is_same<T, bool>::value ) {
143                 result = std::max( result, 0.0 );
144             }
145             return result;
146         }
147         // Test if some value is a member of the confidence interval of the
148         // sample
test_confidence_interval(const double v)149         bool test_confidence_interval( const double v ) const {
150             return is_within_epsilon( v, margin_of_error() );
151         }
152 
is_within_epsilon(const double v,const double epsilon)153         bool is_within_epsilon( const double v, const double epsilon ) const {
154             const double average = avg();
155             return( ( average + epsilon > v ) &&
156                     ( average - epsilon < v ) );
157         }
158         // Theoretically a one-pass formula is more efficient, however because
159         // we keep handles onto _sum and _n as class members and calculate them
160         // on the fly, a one-pass formula is unnecessary because we're already
161         // one pass here.  It may not obvious that even though we're calling
162         // the 'average()' function that's what is happening.
163         double variance( const bool sample_variance = true ) const {
164             double average = avg();
165             double sigma_acc = 0;
166 
167             for( const T v : samples ) {
168                 const double diff = v - average;
169                 sigma_acc += diff * diff;
170             }
171 
172             if( sample_variance ) {
173                 return sigma_acc / static_cast<double>( _n - 1 );
174             }
175             return sigma_acc / static_cast<double>( _n );
176         }
177         // We should only be interested in the sample deviation most of the
178         // time because we can always get more samples.  The way we use tests,
179         // we attempt to use the sample data to generalize about the
180         // population.
181         double stddev( const bool sample_deviation = true ) const {
182             return std::sqrt( variance( sample_deviation ) );
183         }
184 
types()185         int types() const {
186             return _types;
187         }
sum()188         double sum() const {
189             return _sum;
190         }
max()191         T max() const {
192             return _max;
193         }
min()194         T min() const {
195             return _min;
196         }
avg()197         double avg() const {
198             return _sum / static_cast<double>( _n );
199         }
n()200         int n() const {
201             return _n;
202         }
get_samples()203         std::vector<T> get_samples() {
204             return samples;
205         }
206 };
207 
208 class BinomialMatcher : public Catch::MatcherBase<int>
209 {
210     public:
211         BinomialMatcher( int num_samples, double p, double max_deviation );
212         bool match( const int &obs ) const override;
213         std::string describe() const override;
214     private:
215         int num_samples_;
216         double p_;
217         double max_deviation_;
218         double expected_;
219         double margin_;
220 };
221 
222 // Can be used to test that a value is a plausible observation from a binomial
223 // distribution.  Uses a normal approximation to the binomial, and permits a
224 // deviation up to max_deviation (measured in standard deviations).
225 inline BinomialMatcher IsBinomialObservation(
226     const int num_samples, const double p, const double max_deviation = Z99_99 )
227 {
228     return BinomialMatcher( num_samples, p, max_deviation );
229 }
230 
231 #endif // CATA_TESTS_TEST_STATISTICS_H
232