1 /*
2 * Copyright (c) 2016, Facebook, Inc.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree. An additional grant
7 * of patent rights can be found in the PATENTS file in the same directory.
8 */
9
#include <fatal/math/statistical_moments.h>
#include <fatal/type/scalar.h>

#include <fatal/test/driver.h>

#include <cmath>
#include <cstddef>
#include <random>
#include <ratio>
#include <vector>
18
19 namespace fatal {
20
// the floating point type of the samples used by every test in this file
using value_type = long double;

// number of samples added, one at a time, by the `state` test
using iterations = size_constant<100>;

// how many rounds `test_statistical_moments` runs for each randomized test
using rounds = size_constant<100>;

// the number of samples drawn on each round of a randomized test
using samples = size_constant<100000>;

// the relative error (1%) allowed when comparing results
using relative_error = std::ratio<1, 100>;

// the absolute error (10^-3) allowed when comparing the streaming results
// against the ones calculated by the test itself
using absolute_error = std::ratio<1, 1000>;

// the absolute error (10^-7) passed to the comparisons between merged
// instances and the instance that received all the samples
using equality_precision = std::ratio<1, 10000000>;
40
41 ////////////////////
42 // TEST UTILITIES //
43 ////////////////////
44
// Draws `case_samples` values by repeatedly calling `distribution(rng)` and
// returns them, in the order they were drawn, as a vector of `T`.
template <typename T, typename TRNG, typename TDistribution>
std::vector<T> random_samples(
  std::size_t case_samples,
  TRNG &&rng,
  TDistribution &&distribution
) {
  std::vector<T> drawn;

  // a single allocation up-front, the final size is already known
  drawn.reserve(case_samples);

  for (std::size_t remaining = case_samples; remaining != 0; --remaining) {
    drawn.push_back(distribution(rng));
  }

  return drawn;
}
60
61 template <typename T, typename TRNG, typename TDistribution>
statistical_moments_test_round(std::size_t const case_samples,TRNG && rng,TDistribution && distribution)62 void statistical_moments_test_round(
63 std::size_t const case_samples,
64 TRNG &&rng,
65 TDistribution &&distribution
66 ) {
67 using case_value_type = T;
68
69 FATAL_ASSERT_GT(case_samples, 0);
70
71 auto const v1 = random_samples<case_value_type>(
72 case_samples / 2, rng, distribution
73 );
74 auto const v2 = random_samples<case_value_type>(
75 case_samples / 2 + (case_samples & 1),
76 rng,
77 distribution
78 );
79 FATAL_ASSERT_EQ(case_samples, v1.size() + v2.size());
80
81 statistical_moments<case_value_type> moments; // all samples
82 statistical_moments<case_value_type> moments1; // samples in v1
83 statistical_moments<case_value_type> moments2; // samples in v2
84
85 FATAL_EXPECT_TRUE(moments.empty());
86 FATAL_EXPECT_TRUE(moments1.empty());
87 FATAL_EXPECT_TRUE(moments2.empty());
88
89 // calculates the moments using the streaming algorithm
90
91 case_value_type sum = 0;
92
93 for (auto i: v1) {
94 sum += i;
95 moments.add(i);
96 moments1.add(i);
97 }
98
99 for (auto i: v2) {
100 sum += i;
101 moments.add(i);
102 moments2.add(i);
103 }
104
105 FATAL_EXPECT_FALSE(moments.empty());
106 FATAL_EXPECT_FALSE(moments1.empty());
107 FATAL_EXPECT_FALSE(moments2.empty());
108
109 FATAL_EXPECT_EQ(case_samples, moments.size());
110 FATAL_EXPECT_EQ(v1.size(), moments1.size());
111 FATAL_EXPECT_EQ(v2.size(), moments2.size());
112
113 // calculates the moments using the traditional approach, for later comparison
114
115 auto const mean = sum / case_samples;
116
117 case_value_type cumulant_1 = 0;
118 case_value_type cumulant_3 = 0;
119 case_value_type cumulant_4 = 0;
120
121 for (auto i: v1) {
122 auto const x = i - mean;
123 cumulant_1 += x * x;
124 cumulant_3 += x * x * x;
125 cumulant_4 += x * x * x * x;
126 }
127
128 for (auto i: v2) {
129 auto const x = i - mean;
130 cumulant_1 += x * x;
131 cumulant_3 += x * x * x;
132 cumulant_4 += x * x * x * x;
133 }
134
135 auto const variance = cumulant_1 / case_samples;
136 auto const standard_deviation = std::sqrt(variance);
137 auto const standard_deviation_3 = variance * standard_deviation;
138 auto const kurtosis = cumulant_4 / case_samples / (variance * variance) - 3;
139 auto const skewness = cumulant_3 / case_samples / standard_deviation_3;
140
141 // checks the results
142
143 // checks if results are different by no more than precision
144 # define TEST_ABSOLUTE_ERROR_IMPL( \
145 expected, \
146 actual, \
147 relative_error, \
148 absolute_error \
149 ) \
150 do { \
151 FATAL_ASSERT_GE(absolute_error, 0); \
152 FATAL_EXPECT_LE(expected - absolute_error, actual); \
153 FATAL_EXPECT_GE(expected + absolute_error, actual); \
154 } while (false)
155
156 // checks if results are different by no more than threshold percent,
157 // or precision if such percentage of the expected result is smaller
158 // than the precision
159 # define TEST_RELATIVE_ERROR_IMPL( \
160 expected, \
161 actual, \
162 relative_error, \
163 absolute_error \
164 ) \
165 do { \
166 auto const margin = std::abs(expected * relative_error); \
167 \
168 if (margin < absolute_error) { \
169 TEST_ABSOLUTE_ERROR_IMPL( \
170 expected, \
171 actual, \
172 relative_error, \
173 absolute_error \
174 ); \
175 } else { \
176 FATAL_EXPECT_LE(expected - margin, actual); \
177 FATAL_EXPECT_GE(expected + margin, actual); \
178 } \
179 } while (false)
180
181 # define TEST_ALL_IMPL(Fn, ...) \
182 do { \
183 Fn(TEST_RELATIVE_ERROR_IMPL, mean, __VA_ARGS__); \
184 Fn(TEST_RELATIVE_ERROR_IMPL, variance, __VA_ARGS__); \
185 Fn(TEST_RELATIVE_ERROR_IMPL, standard_deviation, __VA_ARGS__); \
186 Fn(TEST_ABSOLUTE_ERROR_IMPL, skewness, __VA_ARGS__); \
187 Fn(TEST_ABSOLUTE_ERROR_IMPL, kurtosis, __VA_ARGS__); \
188 } while (false)
189
190 # define TEST_IMPL(Fn, what, subject) \
191 Fn( \
192 what, \
193 subject.what(), \
194 (to_scalar<relative_error, long double>()), \
195 (to_scalar<absolute_error, long double>()) \
196 )
197
198 // compares `moments` against the expected results
199 TEST_ALL_IMPL(TEST_IMPL, moments);
200
201 # undef TEST_IMPL
202 # define TEST_IMPL(Fn, what, expected, actual) \
203 Fn( \
204 expected.what(), \
205 actual.what(), \
206 (to_scalar<relative_error, long double>()), \
207 (to_scalar<equality_precision, long double>()) \
208 )
209
210 // tests merging into an empty instance
211 statistical_moments<case_value_type> merge_into_empty;
212 merge_into_empty.merge(moments);
213 TEST_ALL_IMPL(TEST_IMPL, moments, merge_into_empty);
214
215 // tests merging two sets of samples
216 auto merged_subsets(moments1);
217 merged_subsets.merge(moments2);
218 TEST_ALL_IMPL(TEST_IMPL, moments, merged_subsets);
219
220 # undef TEST_IMPL
221 # undef TEST_ALL_IMPL
222 }
223
224 template <typename T, typename TRNG, typename TDistribution>
test_statistical_moments(std::size_t const case_samples,TRNG && rng,TDistribution && distribution)225 void test_statistical_moments(
226 std::size_t const case_samples,
227 TRNG &&rng,
228 TDistribution &&distribution
229 ) {
230 for (auto i = rounds::value; i--; ) {
231 statistical_moments_test_round<T>(case_samples, rng, distribution);
232 }
233 }
234
235 ///////////
236 // TESTS //
237 ///////////
238
239 // the minimum value of a sample on uniform distribution randomized tests
240 using uniform_min = std::ratio<-1000>;
241
242 // the maximum value of a sample on uniform distribution randomized tests
243 using uniform_max = std::ratio<1000>;
244
245 static_assert(std::ratio_less_equal<uniform_min, uniform_max>::value, "");
246
FATAL_TEST(statistical_moments,uniform_distribution)247 FATAL_TEST(statistical_moments, uniform_distribution) {
248 test_statistical_moments<value_type>(
249 samples::value,
250 random_data(),
251 std::uniform_real_distribution<value_type>(
252 to_scalar<uniform_min, long double>(),
253 to_scalar<uniform_max, long double>()
254 )
255 );
256 }
257
258 // the mean to use for normal distribution randomized tests
259 using normal_mean = std::ratio<0>;
260
261 // the standard deviation to use for normal distribution randomized tests
262 using normal_stddev = std::ratio<10>;
263
FATAL_TEST(statistical_moments,normal_distribution)264 FATAL_TEST(statistical_moments, normal_distribution) {
265 test_statistical_moments<value_type>(
266 samples::value,
267 random_data(),
268 std::normal_distribution<value_type>(
269 to_scalar<normal_mean, long double>(),
270 to_scalar<normal_stddev, long double>()
271 )
272 );
273 }
274
275 // the rate (lambda) to use for exponential distribution randomized tests
276 using exponential_lambda = std::ratio<10>;
277
FATAL_TEST(statistical_moments,exponential_distribution)278 FATAL_TEST(statistical_moments, exponential_distribution) {
279 test_statistical_moments<value_type>(
280 samples::value,
281 random_data(),
282 std::exponential_distribution<value_type>(
283 to_scalar<exponential_lambda, long double>()
284 )
285 );
286 }
287
// An instance constructed from the `state()` of another one is expected to
// compare equal to it, both when empty and after each added sample.
FATAL_TEST(statistical_moments, state) {
  random_data rng;
  std::normal_distribution<value_type> distribution(
    to_scalar<normal_mean, value_type>(),
    to_scalar<normal_stddev, value_type>()
  );

  statistical_moments<value_type> moments;

  // no samples added yet
  statistical_moments<value_type> empty_copy(moments.state());
  FATAL_EXPECT_EQ(moments, empty_copy);

  // the state is checked again after every new sample
  for (auto remaining = iterations::value; remaining != 0; --remaining) {
    auto const sample = distribution(rng);
    moments.add(sample);

    statistical_moments<value_type> copy(moments.state());
    FATAL_EXPECT_EQ(moments, copy);
  }
}
305
306 } // namespace fatal {
307