1 /* 2 Copyright 2015-2017 Oliver Heimlich 3 Copyright 2017 Joel Dahne 4 5 This program is free software; you can redistribute it and/or modify 6 it under the terms of the GNU General Public License as published by 7 the Free Software Foundation; either version 3 of the License, or 8 (at your option) any later version. 9 10 This program is distributed in the hope that it will be useful, 11 but WITHOUT ANY WARRANTY; without even the implied warranty of 12 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 GNU General Public License for more details. 14 15 You should have received a copy of the GNU General Public License 16 along with this program; if not, see <http://www.gnu.org/licenses/>. 17 */ 18 19 #include <octave/oct.h> 20 #include <mpfr.h> 21 #include "mpfr_commons.h" 22 23 DEFUN_DLD (mpfr_vector_sum_d, args, nargout, 24 "-*- texinfo -*-\n" 25 "@documentencoding UTF-8\n" 26 "@deftypefun {[@var{S}, @var{E}] =} mpfr_vector_sum_d (@var{R}, @var{X}, " 27 "@var{dim})\n\n" 28 "Compute the sum @var{S} of all numbers in a binary64 array @var{X} along " 29 "dimension @var{dim} with correctly rounded result." 30 "\n\n" 31 "@var{R} is the rounding direction (@option{0}: towards zero, @option{0.5}: " 32 "towards nearest and ties to even, @option{+inf}: towards positive " 33 "infinity, @option{-inf}: towards negative infinity)." 34 "\n\n" 35 "The result is guaranteed to be correctly rounded. That is, the sum is " 36 "evaluated with (virtually) infinite precision and the exact result is " 37 "approximated with a binary64 number using the desired rounding direction." 38 "\n\n" 39 "If one element of the sum is NaN or infinities of both signs are " 40 "encountered, the result will be NaN. An @emph{exact} zero is returned as " 41 "+0 in all rounding directions, except for rounding towards negative " 42 "infinity, where -0 is returned." 43 "\n\n" 44 "A second output parameter yields an approximation of the error. The " 45 "difference between the exact sum over @var{X} and @var{S} is approximated " 46 "by a second binary64 number @var{E} with rounding towards zero." 47 "\n\n" 48 "@example\n" 49 "@group\n" 50 "mpfr_vector_sum_d (-inf, [1, eps/2, realmax, -realmax], 2) == 1\n" 51 " @result{} 1\n" 52 "mpfr_vector_sum_d (+inf, [1, eps/2, realmax, -realmax], 2) == 1 + eps\n" 53 " @result{} 1\n" 54 "@end group\n" 55 "@end example\n" 56 "@seealso{sum}\n" 57 "@end deftypefun" 58 ) 59 { 60 // Check call syntax 61 int nargin = args.length (); 62 if (nargin != 3) 63 { 64 print_usage (); 65 return octave_value_list (); 66 } 67 68 // Read parameters 69 const mpfr_rnd_t rnd = parse_rounding_mode (args (0).scalar_value ()); 70 const NDArray array = args (1).array_value (); 71 const octave_idx_type dim = args (2).scalar_value (); 72 if (error_state) 73 return octave_value_list (); 74 75 if (dim > array.ndims ()) 76 { 77 // Nothing to be done 78 return octave_value (array); 79 } 80 81 // Determine size for result 82 dim_vector array_dims = array.dims (); 83 // Inconsistency: sum ([]) = 0 84 if (array.ndims () == 2 && array_dims(0) == 0 && array_dims(1) == 0) 85 array_dims(1) = 1; 86 dim_vector array_cdims = array_dims.cumulative (); 87 dim_vector result_dims = array_dims; 88 result_dims (dim - 1) = 1; 89 result_dims.chop_trailing_singletons (); 90 dim_vector result_cdims = result_dims.cumulative (); 91 NDArray result_sum (result_dims); 92 NDArray result_error (result_dims); 93 94 mpfr_t accu; 95 mpfr_init2 (accu, BINARY64_ACCU_PRECISION); 96 97 octave_idx_type step; 98 if (dim > 1) 99 step = array_cdims (dim - 2); 100 else 101 step = 1; 102 octave_idx_type n = result_dims.numel (); 103 octave_idx_type m = array_dims (dim - 1) * step; 104 octave_idx_type idx_array; 105 octave_idx_type idx_result; 106 107 OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, idx, array.ndims (), 0); 108 109 for (int i = 0; i < n; i ++) 110 { 111 idx_array = array_cdims.cum_compute_index (idx); 112 idx_result = result_cdims.cum_compute_index (idx); 113 mpfr_set_zero (accu, 0); 114 115 // Perform the summation 116 for (octave_idx_type j = 0; j < m; j += step) 117 { 118 int exact = mpfr_add_d (accu, accu, array (idx_array + j), rnd); 119 if (exact != 0) 120 error ("mpfr_vector_sum_d: Failed to compute exact sum"); 121 if (mpfr_nan_p (accu)) 122 // Short-Circtuit if one addend is NAN or if -INF + INF 123 break; 124 } 125 126 // Check the result 127 if (mpfr_nan_p (accu) != 0) 128 { 129 result_sum.elem (idx_result) = NAN; 130 result_error (idx_result) = NAN; 131 } 132 else 133 if (mpfr_cmp_d (accu, 0.0) == 0) 134 { 135 // exact zero 136 if (rnd == MPFR_RNDD) 137 result_sum.elem (idx_result) = -0.0; 138 else 139 result_sum.elem (idx_result) = +0.0; 140 result_error (idx_result) = 0.0; 141 } 142 else 143 { 144 const double sum = mpfr_get_d (accu, rnd); 145 result_sum.elem (idx_result) = sum; 146 if (nargout >= 2) 147 { 148 mpfr_sub_d (accu, accu, sum, MPFR_RNDA); 149 const double error = mpfr_get_d (accu, MPFR_RNDA); 150 result_error.elem (idx_result) = error; 151 } 152 } 153 result_dims.increment_index (idx); 154 } 155 156 mpfr_clear (accu); 157 octave_value_list result; 158 result (0) = octave_value (result_sum); 159 result (1) = octave_value (result_error); 160 return result; 161 } 162 163 /* 164 %!assert (mpfr_vector_sum_d (0, [eps, realmax, realmax, -realmax, -realmax], 2), eps) 165 %!assert (mpfr_vector_sum_d (-inf, [eps/2, 1], 2), 1) 166 %!assert (mpfr_vector_sum_d (+inf, [eps/2, 1], 2), 1 + eps) 167 %!test 168 %! a = inf (infsup ("0X1.1111111111111P+100")); 169 %! b = inf (infsup ("0X1.1111111111111P+1")); 170 %! [s, e] = mpfr_vector_sum_d (0.5, [a, b], 2); 171 %! assert (s, a); 172 %! assert (e, b); 173 %!test 174 %! a = inf (infsup ("0X1.1111111111111P+53")); 175 %! b = inf (infsup ("0X1.1111111111111P+1")); 176 %! c = inf (infsup ("0X1.1111111111112P+53")); 177 %! d = inf (infsup ("0X1.111111111111P-3")); 178 %! [s, e] = mpfr_vector_sum_d (0.5, [a, b], 2); 179 %! assert (s, c); 180 %! assert (e, d); 181 %!test 182 %! a = inf (infsup ("0X1.1111111111111P+2")); 183 %! b = inf (infsup ("0X1.1111111111111P+1")); 184 %! c = inf (infsup ("0X1.999999999999AP+2")); 185 %! d = inf (infsup ("-0X1P-51")); 186 %! [s, e] = mpfr_vector_sum_d (0.5, [a, b], 2); 187 %! assert (s, c); 188 %! assert (e, d); 189 %!test 190 %! for dim = 1:6 191 %! assert (mpfr_vector_sum_d (0.5, ones (1, 2, 3, 4, 5), dim), sum (ones (1, 2, 3, 4, 5), dim)); 192 %! endfor 193 194 %!shared testdata 195 %! testdata = load (file_in_loadpath ("test/itl.mat")); 196 197 %!test 198 %! # Scalar evaluation 199 %! testcases = testdata.NoSignal.double.sum_nearest; 200 %! for testcase = [testcases]' 201 %! assert (isequaln (... 202 %! mpfr_vector_sum_d (0.5, testcase.in{1}, 2), ... 203 %! testcase.out)); 204 %! endfor 205 206 %!test 207 %! # Vector evaluation 208 %! testcases = testdata.NoSignal.double.sum_nearest; 209 %! in1 = vertcat (testcases.in); 210 %! in1 = cell2mat (cellfun ("postpad", in1, {(max (cellfun ("numel", in1)))}, "UniformOutput", false)); 211 %! out = vertcat (testcases.out); 212 %! assert (isequaln (mpfr_vector_sum_d (0.5, in1, 2), out)); 213 214 %!test 215 %! # Scalar evaluation 216 %! testcases = testdata.NoSignal.double.sum_abs_nearest; 217 %! for testcase = [testcases]' 218 %! assert (isequaln (... 219 %! mpfr_vector_sum_d (0.5, abs (testcase.in{1}), 2), ... 220 %! testcase.out)); 221 %! endfor 222 223 %!test 224 %! # Vector evaluation 225 %! testcases = testdata.NoSignal.double.sum_abs_nearest; 226 %! in1 = vertcat (testcases.in); 227 %! in1 = cell2mat (cellfun ("postpad", in1, {(max (cellfun ("numel", in1)))}, "UniformOutput", false)); 228 %! out = vertcat (testcases.out); 229 %! assert (isequaln (mpfr_vector_sum_d (0.5, abs (in1), 2), out)); 230 231 */ 232