1 #pragma once 2 #ifndef CATA_TESTS_TEST_STATISTICS_H 3 #define CATA_TESTS_TEST_STATISTICS_H 4 5 #include <algorithm> 6 #include <cmath> 7 #include <iosfwd> 8 #include <limits> 9 #include <type_traits> 10 #include <vector> 11 12 #include "catch/catch.hpp" 13 14 // Z-value for confidence interval 15 constexpr double Z95 = 1.96; 16 constexpr double Z99 = 2.576; 17 constexpr double Z99_9 = 3.291; 18 constexpr double Z99_99 = 3.891; 19 constexpr double Z99_999 = 4.5; 20 constexpr double Z99_999_9 = 5.0; 21 22 // Useful to specify a range using midpoint +/- ε which is easier to parse how 23 // wide a range actually is vs just upper and lower 24 struct epsilon_threshold { 25 double midpoint; 26 double epsilon; 27 }; 28 29 // Upper/lower bound threshold useful for asymmetric thresholds 30 struct upper_lower_threshold { 31 double lower_thresh; 32 double upper_thresh; 33 }; 34 35 // we cache the margin of error so when adding a new value we must invalidate 36 // it so it gets calculated a again 37 static constexpr double invalid_err = -1; 38 39 template<typename T> 40 class statistics 41 { 42 private: 43 int _types; 44 int _n; 45 double _sum; 46 double _error; 47 const double _Z; 48 const double _Zsq; 49 T _max; 50 T _min; 51 std::vector< T > samples; 52 public: 53 explicit statistics( const double Z = Z99_9 ) : 54 _types( 0 ), _n( 0 ), _sum( 0 ), _error( invalid_err ), 55 _Z( Z ), _Zsq( Z * Z ), _max( std::numeric_limits<T>::min() ), 56 _min( std::numeric_limits<T>::max() ) {} 57 new_type()58 void new_type() { 59 _types++; 60 } add(T new_val)61 void add( T new_val ) { 62 _error = invalid_err; 63 _n++; 64 _sum += new_val; 65 _max = std::max( _max, new_val ); 66 _min = std::min( _min, new_val ); 67 samples.push_back( new_val ); 68 } 69 70 // Adjusted Wald error is only valid for a discrete binary test. Note 71 // because error takes into account population, it is only valid to 72 // test against the upper/lower bound. 73 // 74 // The goal here is to get the most accurate statistics about the 75 // smallest sample size feasible. The tests end up getting run many 76 // times over a short period, so any real issue may sometimes get a 77 // false positive, but over a series of runs will get shaken out in an 78 // obvious way. 79 // 80 // Outside of this class, this should only be used for debugging 81 // purposes. 82 template<typename U = T> 83 typename std::enable_if< std::is_same< U, bool >::value, double >::type margin_of_error()84 margin_of_error() { 85 if( _error != invalid_err ) { 86 return _error; 87 } 88 // Implementation of outline from https://measuringu.com/ci-five-steps/ 89 const double adj_numerator = ( _Zsq / 2 ) + _sum; 90 const double adj_denominator = _Zsq + _n; 91 const double adj_proportion = adj_numerator / adj_denominator; 92 const double a = adj_proportion * ( 1.0 - adj_proportion ); 93 const double b = a / adj_denominator; 94 const double c = std::sqrt( b ); 95 _error = c * _Z; 96 return _error; 97 } 98 // Standard error is intended to be used with continuous data samples. 99 // We're using an approximation here so it is only appropriate to use 100 // the upper/lower bound to test for reasons similar to adjusted Wald 101 // error. 102 // Outside of this class, this should only be used for debugging purposes. 103 // https://measuringu.com/ci-five-steps/ 104 template<typename U = T> 105 typename std::enable_if < ! std::is_same< U, bool >::value, double >::type margin_of_error()106 margin_of_error() { 107 if( _error != invalid_err ) { 108 return _error; 109 } 110 const double std_err = stddev() / std::sqrt( _n ); 111 _error = std_err * _Z; 112 return _error; 113 } 114 115 /** Use to continue testing until we are sure whether the result is 116 * inside or outside the target. 117 * 118 * Returns true if the confidence interval partially overlaps the target region. 119 */ uncertain_about(const epsilon_threshold & t)120 bool uncertain_about( const epsilon_threshold &t ) { 121 return !test_threshold( t ) && // Inside target 122 t.midpoint - t.epsilon < upper() && // Below target 123 t.midpoint + t.epsilon > lower(); // Above target 124 } 125 test_threshold(const epsilon_threshold & t)126 bool test_threshold( const epsilon_threshold &t ) { 127 return ( ( t.midpoint - t.epsilon ) < lower() && 128 ( t.midpoint + t.epsilon ) > upper() ); 129 } test_threshold(const upper_lower_threshold & t)130 bool test_threshold( const upper_lower_threshold &t ) { 131 return ( t.lower_thresh < lower() && t.upper_thresh > upper() ); 132 } upper()133 double upper() { 134 double result = avg() + margin_of_error(); 135 if( std::is_same<T, bool>::value ) { 136 result = std::min( result, 1.0 ); 137 } 138 return result; 139 } lower()140 double lower() { 141 double result = avg() - margin_of_error(); 142 if( std::is_same<T, bool>::value ) { 143 result = std::max( result, 0.0 ); 144 } 145 return result; 146 } 147 // Test if some value is a member of the confidence interval of the 148 // sample test_confidence_interval(const double v)149 bool test_confidence_interval( const double v ) const { 150 return is_within_epsilon( v, margin_of_error() ); 151 } 152 is_within_epsilon(const double v,const double epsilon)153 bool is_within_epsilon( const double v, const double epsilon ) const { 154 const double average = avg(); 155 return( ( average + epsilon > v ) && 156 ( average - epsilon < v ) ); 157 } 158 // Theoretically a one-pass formula is more efficient, however because 159 // we keep handles onto _sum and _n as class members and calculate them 160 // on the fly, a one-pass formula is unnecessary because we're already 161 // one pass here. It may not obvious that even though we're calling 162 // the 'average()' function that's what is happening. 163 double variance( const bool sample_variance = true ) const { 164 double average = avg(); 165 double sigma_acc = 0; 166 167 for( const T v : samples ) { 168 const double diff = v - average; 169 sigma_acc += diff * diff; 170 } 171 172 if( sample_variance ) { 173 return sigma_acc / static_cast<double>( _n - 1 ); 174 } 175 return sigma_acc / static_cast<double>( _n ); 176 } 177 // We should only be interested in the sample deviation most of the 178 // time because we can always get more samples. The way we use tests, 179 // we attempt to use the sample data to generalize about the 180 // population. 181 double stddev( const bool sample_deviation = true ) const { 182 return std::sqrt( variance( sample_deviation ) ); 183 } 184 types()185 int types() const { 186 return _types; 187 } sum()188 double sum() const { 189 return _sum; 190 } max()191 T max() const { 192 return _max; 193 } min()194 T min() const { 195 return _min; 196 } avg()197 double avg() const { 198 return _sum / static_cast<double>( _n ); 199 } n()200 int n() const { 201 return _n; 202 } get_samples()203 std::vector<T> get_samples() { 204 return samples; 205 } 206 }; 207 208 class BinomialMatcher : public Catch::MatcherBase<int> 209 { 210 public: 211 BinomialMatcher( int num_samples, double p, double max_deviation ); 212 bool match( const int &obs ) const override; 213 std::string describe() const override; 214 private: 215 int num_samples_; 216 double p_; 217 double max_deviation_; 218 double expected_; 219 double margin_; 220 }; 221 222 // Can be used to test that a value is a plausible observation from a binomial 223 // distribution. Uses a normal approximation to the binomial, and permits a 224 // deviation up to max_deviation (measured in standard deviations). 225 inline BinomialMatcher IsBinomialObservation( 226 const int num_samples, const double p, const double max_deviation = Z99_99 ) 227 { 228 return BinomialMatcher( num_samples, p, max_deviation ); 229 } 230 231 #endif // CATA_TESTS_TEST_STATISTICS_H 232