1 /*
2   Copyright 2015-2017 Oliver Heimlich
3   Copyright 2017 Joel Dahne
4 
5   This program is free software; you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation; either version 3 of the License, or
8   (at your option) any later version.
9 
10   This program is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14 
15   You should have received a copy of the GNU General Public License
16   along with this program; if not, see <http://www.gnu.org/licenses/>.
17 */
18 
19 #include <octave/oct.h>
20 #include <mpfr.h>
21 #include "mpfr_commons.h"
22 
23 DEFUN_DLD (mpfr_vector_sum_d, args, nargout,
24   "-*- texinfo -*-\n"
25   "@documentencoding UTF-8\n"
26   "@deftypefun {[@var{S}, @var{E}] =} mpfr_vector_sum_d (@var{R}, @var{X}, "
27   "@var{dim})\n\n"
28   "Compute the sum @var{S} of all numbers in a binary64 array @var{X} along "
29   "dimension @var{dim} with correctly rounded result."
30   "\n\n"
31   "@var{R} is the rounding direction (@option{0}: towards zero, @option{0.5}: "
32   "towards nearest and ties to even, @option{+inf}: towards positive "
33   "infinity, @option{-inf}: towards negative infinity)."
34   "\n\n"
35   "The result is guaranteed to be correctly rounded.  That is, the sum is "
36   "evaluated with (virtually) infinite precision and the exact result is "
37   "approximated with a binary64 number using the desired rounding direction."
38   "\n\n"
39   "If one element of the sum is NaN or infinities of both signs are "
40   "encountered, the result will be NaN.  An @emph{exact} zero is returned as "
41   "+0 in all rounding directions, except for rounding towards negative "
42   "infinity, where -0 is returned."
43   "\n\n"
44   "A second output parameter yields an approximation of the error.  The "
45   "difference between the exact sum over @var{X} and @var{S} is approximated "
46   "by a second binary64 number @var{E} with rounding towards zero."
47   "\n\n"
48   "@example\n"
49   "@group\n"
50   "mpfr_vector_sum_d (-inf, [1, eps/2, realmax, -realmax], 2) == 1\n"
51   "  @result{} 1\n"
52   "mpfr_vector_sum_d (+inf, [1, eps/2, realmax, -realmax], 2) == 1 + eps\n"
53   "  @result{} 1\n"
54   "@end group\n"
55   "@end example\n"
56   "@seealso{sum}\n"
57   "@end deftypefun"
58   )
59 {
60   // Check call syntax
61   int nargin = args.length ();
62   if (nargin != 3)
63     {
64       print_usage ();
65       return octave_value_list ();
66     }
67 
68   // Read parameters
69   const mpfr_rnd_t rnd      = parse_rounding_mode (args (0).scalar_value ());
70   const NDArray    array    = args (1).array_value ();
71   const octave_idx_type dim = args (2).scalar_value ();
72   if (error_state)
73     return octave_value_list ();
74 
75   if (dim > array.ndims ())
76     {
77       // Nothing to be done
78       return octave_value (array);
79     }
80 
81   // Determine size for result
82   dim_vector array_dims = array.dims ();
83   // Inconsistency: sum ([]) = 0
84   if (array.ndims () == 2 && array_dims(0) == 0 && array_dims(1) == 0)
85       array_dims(1) = 1;
86   dim_vector array_cdims = array_dims.cumulative ();
87   dim_vector result_dims = array_dims;
88   result_dims (dim - 1) = 1;
89   result_dims.chop_trailing_singletons ();
90   dim_vector result_cdims = result_dims.cumulative ();
91   NDArray result_sum (result_dims);
92   NDArray result_error (result_dims);
93 
94   mpfr_t accu;
95   mpfr_init2 (accu, BINARY64_ACCU_PRECISION);
96 
97   octave_idx_type step;
98   if (dim > 1)
99       step = array_cdims (dim - 2);
100   else
101       step = 1;
102   octave_idx_type n = result_dims.numel ();
103   octave_idx_type m = array_dims (dim - 1) * step;
104   octave_idx_type idx_array;
105   octave_idx_type idx_result;
106 
107   OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, idx, array.ndims (), 0);
108 
109   for (int i = 0; i < n; i ++)
110   {
111       idx_array = array_cdims.cum_compute_index (idx);
112       idx_result = result_cdims.cum_compute_index (idx);
113       mpfr_set_zero (accu, 0);
114 
115       // Perform the summation
116       for (octave_idx_type j = 0; j < m; j += step)
117       {
118           int exact = mpfr_add_d (accu, accu, array (idx_array + j), rnd);
119           if (exact != 0)
120               error ("mpfr_vector_sum_d: Failed to compute exact sum");
121           if (mpfr_nan_p (accu))
122               // Short-Circtuit if one addend is NAN or if -INF + INF
123               break;
124       }
125 
126       // Check the result
127       if (mpfr_nan_p (accu) != 0)
128       {
129           result_sum.elem (idx_result) = NAN;
130           result_error (idx_result) = NAN;
131       }
132       else
133           if (mpfr_cmp_d (accu, 0.0) == 0)
134           {
135               // exact zero
136               if (rnd == MPFR_RNDD)
137                   result_sum.elem (idx_result) = -0.0;
138               else
139                   result_sum.elem (idx_result) = +0.0;
140               result_error (idx_result) = 0.0;
141           }
142           else
143           {
144               const double sum = mpfr_get_d (accu, rnd);
145               result_sum.elem (idx_result) = sum;
146               if (nargout >= 2)
147               {
148                   mpfr_sub_d (accu, accu, sum, MPFR_RNDA);
149                   const double error = mpfr_get_d (accu, MPFR_RNDA);
150                   result_error.elem (idx_result) = error;
151               }
152           }
153       result_dims.increment_index (idx);
154   }
155 
156   mpfr_clear (accu);
157   octave_value_list result;
158   result (0) = octave_value (result_sum);
159   result (1) = octave_value (result_error);
160   return result;
161 }
162 
163 /*
164 %!assert (mpfr_vector_sum_d (0, [eps, realmax, realmax, -realmax, -realmax], 2), eps)
165 %!assert (mpfr_vector_sum_d (-inf, [eps/2, 1], 2), 1)
166 %!assert (mpfr_vector_sum_d (+inf, [eps/2, 1], 2), 1 + eps)
167 %!test
168 %!  a = inf (infsup ("0X1.1111111111111P+100"));
169 %!  b = inf (infsup ("0X1.1111111111111P+1"));
170 %!  [s, e] = mpfr_vector_sum_d (0.5, [a, b], 2);
171 %!  assert (s, a);
172 %!  assert (e, b);
173 %!test
174 %!  a = inf (infsup ("0X1.1111111111111P+53"));
175 %!  b = inf (infsup ("0X1.1111111111111P+1"));
176 %!  c = inf (infsup ("0X1.1111111111112P+53"));
177 %!  d = inf (infsup ("0X1.111111111111P-3"));
178 %!  [s, e] = mpfr_vector_sum_d (0.5, [a, b], 2);
179 %!  assert (s, c);
180 %!  assert (e, d);
181 %!test
182 %!  a = inf (infsup ("0X1.1111111111111P+2"));
183 %!  b = inf (infsup ("0X1.1111111111111P+1"));
184 %!  c = inf (infsup ("0X1.999999999999AP+2"));
185 %!  d = inf (infsup ("-0X1P-51"));
186 %!  [s, e] = mpfr_vector_sum_d (0.5, [a, b], 2);
187 %!  assert (s, c);
188 %!  assert (e, d);
189 %!test
190 %!  for dim = 1:6
191 %!    assert (mpfr_vector_sum_d (0.5, ones (1, 2, 3, 4, 5), dim), sum (ones (1, 2, 3, 4, 5), dim));
192 %!  endfor
193 
194 %!shared testdata
195 %! testdata = load (file_in_loadpath ("test/itl.mat"));
196 
197 %!test
198 %! # Scalar evaluation
199 %! testcases = testdata.NoSignal.double.sum_nearest;
200 %! for testcase = [testcases]'
201 %!   assert (isequaln (...
202 %!     mpfr_vector_sum_d (0.5, testcase.in{1}, 2), ...
203 %!     testcase.out));
204 %! endfor
205 
206 %!test
207 %! # Vector evaluation
208 %! testcases = testdata.NoSignal.double.sum_nearest;
209 %! in1 = vertcat (testcases.in);
210 %! in1 = cell2mat (cellfun ("postpad", in1, {(max (cellfun ("numel", in1)))}, "UniformOutput", false));
211 %! out = vertcat (testcases.out);
212 %! assert (isequaln (mpfr_vector_sum_d (0.5, in1, 2), out));
213 
214 %!test
215 %! # Scalar evaluation
216 %! testcases = testdata.NoSignal.double.sum_abs_nearest;
217 %! for testcase = [testcases]'
218 %!   assert (isequaln (...
219 %!     mpfr_vector_sum_d (0.5, abs (testcase.in{1}), 2), ...
220 %!     testcase.out));
221 %! endfor
222 
223 %!test
224 %! # Vector evaluation
225 %! testcases = testdata.NoSignal.double.sum_abs_nearest;
226 %! in1 = vertcat (testcases.in);
227 %! in1 = cell2mat (cellfun ("postpad", in1, {(max (cellfun ("numel", in1)))}, "UniformOutput", false));
228 %! out = vertcat (testcases.out);
229 %! assert (isequaln (mpfr_vector_sum_d (0.5, abs (in1), 2), out));
230 
231 */
232