1 /*
2   Copyright 2015-2016 Oliver Heimlich
3   Copyright 2017 Joel Dahne
4 
5   This program is free software; you can redistribute it and/or modify
6   it under the terms of the GNU General Public License as published by
7   the Free Software Foundation; either version 3 of the License, or
8   (at your option) any later version.
9 
10   This program is distributed in the hope that it will be useful,
11   but WITHOUT ANY WARRANTY; without even the implied warranty of
12   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13   GNU General Public License for more details.
14 
15   You should have received a copy of the GNU General Public License
16   along with this program; if not, see <http://www.gnu.org/licenses/>.
17 */
18 
19 #include <octave/oct.h>
20 #include <mpfr.h>
21 #include "mpfr_commons.h"
22 
interval_vector_dot(NDArray array_xl,NDArray array_yl,NDArray array_xu,NDArray array_yu,octave_idx_type dim)23 std::pair <NDArray, NDArray> interval_vector_dot (
24   NDArray array_xl, NDArray array_yl,
25   NDArray array_xu, NDArray array_yu,
26   octave_idx_type dim)
27 {
28   int dimensions = std::max (array_xl.ndims (), array_yl.ndims ());
29   if (dim > dimensions)
30     {
31       dimensions += 1;
32       dim = dimensions;
33     }
34   dim_vector x_dims = array_xl.dims().redim (dimensions);
35   dim_vector y_dims = array_yl.dims().redim (dimensions);
36   // Inconsistency: dot ([], []) = 0
37   if (array_xl.ndims () == 2 && x_dims(0) == 0 && x_dims(1) == 0 &&
38       array_yl.ndims () == 2 && y_dims(0) == 0 && y_dims(1) == 0)
39     {
40       x_dims(1) = 1;
41       y_dims(1) = 1;
42     }
43   dim_vector x_cdims = x_dims.cumulative ();
44   dim_vector y_cdims = y_dims.cumulative ();
45 
46   // Check if broadcasting can be performed
47   for (int d = 0; d < dimensions; d ++)
48     {
49       if (x_dims (d) != 1 && y_dims (d) != 1 &&
50           x_dims (d) != y_dims (d))
51         error ("mpfr_function_d: Array dimensions must agree!");
52     }
53 
54   // Create result array of right size
55   dim_vector result_dims;
56   result_dims.resize (dimensions);
57 
58   for (int i = 0; i < dimensions; i ++)
59     {
60       if (x_dims(i) != 1)
61         result_dims(i) = x_dims(i);
62       else
63         result_dims(i) = y_dims(i);
64     }
65 
66   result_dims(dim - 1) = 1;
67 
68   std::pair <NDArray, NDArray> result;
69   result.first = NDArray (result_dims);
70   result.second = NDArray (result_dims);
71 
72   // Find increment for elements along dimension dim
73   octave_idx_type x_idx_increment;
74   octave_idx_type y_idx_increment;
75 
76   if (x_dims (dim - 1) == 1)
77     x_idx_increment = 0;
78   else if (dim == 1)
79     x_idx_increment = 1;
80   else
81     x_idx_increment = x_cdims(dim - 2);
82 
83   if (y_dims (dim - 1) == 1)
84     y_idx_increment = 0;
85   else if (dim == 1)
86     y_idx_increment = 1;
87   else
88     y_idx_increment = y_cdims(dim - 2);
89 
90   // Fix broadcasting along all singleton dimensions
91   // Does not work for broadcasting along the first dimension
92   for (int i = 1; i < dimensions; i ++)
93     {
94       if (x_dims(i) == 1)
95         x_cdims(i-1) = 0;
96       if (y_dims(i) == 1)
97         y_cdims(i-1) = 0;
98     }
99 
100   // Check if broadcasting along the first dimension is needed
101   bool broadcast_first = (x_dims(0) == 1 && y_dims(0) != 1)
102       || (x_dims(0) != 1 && y_dims(0) == 1);
103 
104   // Accumulators
105   mpfr_t accu_l, accu_u, mp_addend_l, mp_addend_u, mp_temp;
106   mpfr_init2 (accu_l, BINARY64_ACCU_PRECISION);
107   mpfr_init2 (accu_u, BINARY64_ACCU_PRECISION);
108   mpfr_init2 (mp_addend_l, 2 * BINARY64_PRECISION + 1);
109   mpfr_init2 (mp_addend_u, 2 * BINARY64_PRECISION + 1);
110   mpfr_init2 (mp_temp,     2 * BINARY64_PRECISION + 1);
111 
112   // Loop over all elements in the result
113   octave_idx_type x_idx;
114   octave_idx_type y_idx;
115   OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, idx, dimensions, 0);
116 
117   octave_idx_type n = result.first.numel ();
118   octave_idx_type m;
119   if (x_dims(dim - 1) != 1)
120       m = x_dims(dim - 1);
121   else
122       m = y_dims(dim - 1);
123 
124   for (octave_idx_type i = 0; i < n; i ++)
125     {
126       // Take broadcasting into account
127       x_idx = x_cdims.cum_compute_index (idx);
128       y_idx = y_cdims.cum_compute_index (idx);
129 
130       // Broadcasting along the first dimension needs to be handled
131       // separately
132       if (broadcast_first)
133         {
134           if (x_dims(0) == 1)
135               x_idx -= idx[0];
136           else
137               y_idx -= idx[0];
138         }
139 
140       mpfr_set_zero (accu_l, 0);
141       mpfr_set_zero (accu_u, 0);
142 
143       // Compute result for current element
144       for (octave_idx_type j = 0; j < m; j ++)
145         {
146           const double xl = array_xl.elem (x_idx + x_idx_increment*j);
147           const double xu = array_xu.elem (x_idx + x_idx_increment*j);
148           const double yl = array_yl.elem (y_idx + y_idx_increment*j);
149           const double yu = array_yu.elem (y_idx + y_idx_increment*j);
150 
151           if ((xl == INFINITY && xu == -INFINITY)
152               ||
153               (yl == INFINITY && yu == -INFINITY))
154             {
155               // [Empty] × Anything = [Empty]
156               // [Empty] + Anything = [Empty]
157               mpfr_set_inf (accu_l, +1);
158               mpfr_set_inf (accu_u, -1);
159               break;
160             }
161 
162           if (mpfr_inf_p (accu_l) != 0 && mpfr_inf_p (accu_u) != 0)
163             // [Entire] + Anything = [Entire]
164             continue;
165 
166           if ((xl == 0.0 && xu == 0.0)
167               ||
168               (yl == 0.0 && yu == 0.0))
169             // [0] × Anything = [0]
170             continue;
171 
172           if ((xl == -INFINITY && xu == INFINITY)
173               ||
174               (yl == -INFINITY && yu == INFINITY))
175             {
176               // [Entire] × Anything = [Entire]
177               mpfr_set_inf (accu_l, -1);
178               mpfr_set_inf (accu_u, +1);
179               continue;
180             }
181 
182           // Both factors can be multiplied within 107 bits exactly!
183           mpfr_set_d (mp_addend_l, xl, MPFR_RNDZ);
184           mpfr_mul_d (mp_addend_l, mp_addend_l, yl, MPFR_RNDZ);
185           mpfr_set (mp_addend_u, mp_addend_l, MPFR_RNDZ);
186 
187           // We have to compute the remaining 3 Products and determine min/max
188           if (yl != yu)
189             {
190               mpfr_set_d (mp_temp, xl, MPFR_RNDZ);
191               mpfr_mul_d (mp_temp, mp_temp, yu, MPFR_RNDZ);
192               mpfr_min (mp_addend_l, mp_addend_l, mp_temp, MPFR_RNDZ);
193               mpfr_max (mp_addend_u, mp_addend_u, mp_temp, MPFR_RNDZ);
194             }
195           if (xl != xu)
196             {
197               mpfr_set_d (mp_temp, xu, MPFR_RNDZ);
198               mpfr_mul_d (mp_temp, mp_temp, yl, MPFR_RNDZ);
199               mpfr_min (mp_addend_l, mp_addend_l, mp_temp, MPFR_RNDZ);
200               mpfr_max (mp_addend_u, mp_addend_u, mp_temp, MPFR_RNDZ);
201             }
202           if (xl != xu || yl != yu)
203             {
204               mpfr_set_d (mp_temp, xu, MPFR_RNDZ);
205               mpfr_mul_d (mp_temp, mp_temp, yu, MPFR_RNDZ);
206               mpfr_min (mp_addend_l, mp_addend_l, mp_temp, MPFR_RNDZ);
207               mpfr_max (mp_addend_u, mp_addend_u, mp_temp, MPFR_RNDZ);
208             }
209 
210           // Compute sums
211           if (mpfr_add (accu_l, accu_l, mp_addend_l, MPFR_RNDZ) != 0 ||
212               mpfr_add (accu_u, accu_u, mp_addend_u, MPFR_RNDZ) != 0)
213             error ("failed to compute exact dot product");
214       }
215       result.first(i) = mpfr_get_d (accu_l, MPFR_RNDD);
216       result.second(i) = mpfr_get_d (accu_u, MPFR_RNDU);
217 
218       result_dims.increment_index (idx);
219     }
220   mpfr_clear (accu_l);
221   mpfr_clear (accu_u);
222   mpfr_clear (mp_addend_l);
223   mpfr_clear (mp_addend_u);
224   mpfr_clear (mp_temp);
225 
226   return result;
227 }
228 
vector_dot(mpfr_rnd_t rnd,NDArray array_x,NDArray array_y,octave_idx_type dim,const bool compute_error)229 std::pair <NDArray, NDArray> vector_dot (
230   mpfr_rnd_t rnd,
231   NDArray array_x, NDArray array_y,
232   octave_idx_type dim,
233   const bool compute_error)
234 {
235     int dimensions = std::max (array_x.ndims (), array_y.ndims ());
236     if (dim > dimensions)
237       {
238         dimensions += 1;
239         dim = dimensions;
240       }
241     dim_vector x_dims = array_x.dims().redim (dimensions);
242     dim_vector y_dims = array_y.dims().redim (dimensions);
243     // Inconsistency: dot ([], []) = 0
244     if (array_x.ndims () == 2 && x_dims(0) == 0 && x_dims(1) == 0 &&
245         array_y.ndims () == 2 && y_dims(0) == 0 && y_dims(1) == 0)
246       {
247         x_dims(1) = 1;
248         y_dims(1) = 1;
249       }
250     dim_vector x_cdims = x_dims.cumulative ();
251     dim_vector y_cdims = y_dims.cumulative ();
252 
253     // Check if broadcasting can be performed
254     for (int d = 0; d < dimensions; d ++)
255       {
256         if (x_dims (d) != 1 && y_dims (d) != 1 &&
257             x_dims (d) != y_dims (d))
258           error ("mpfr_function_d: Array dimensions must agree!");
259       }
260 
261     // Create result array of right size
262     dim_vector result_dims;
263     result_dims.resize (dimensions);
264 
265     for (int i = 0; i < dimensions; i ++)
266       {
267         if (x_dims(i) != 1)
268           result_dims(i) = x_dims(i);
269         else
270           result_dims(i) = y_dims(i);
271       }
272 
273     result_dims(dim - 1) = 1;
274 
275     std::pair <NDArray, NDArray> result_and_error;
276     result_and_error.first = NDArray (result_dims);
277     result_and_error.second = NDArray (result_dims);
278 
279     // Find increment for elements along dimension dim
280     octave_idx_type x_idx_increment;
281     octave_idx_type y_idx_increment;
282 
283     if (x_dims (dim - 1) == 1)
284       x_idx_increment = 0;
285     else if (dim == 1)
286       x_idx_increment = 1;
287     else
288       x_idx_increment = x_cdims(dim - 2);
289 
290     if (y_dims (dim - 1) == 1)
291       y_idx_increment = 0;
292     else if (dim == 1)
293       y_idx_increment = 1;
294     else
295       y_idx_increment = y_cdims(dim - 2);
296 
297     // Fix broadcasting along all singleton dimensions
298     // Does not work for broadcasting along the first dimension
299     for (int i = 1; i < dimensions; i ++)
300       {
301         if (x_dims(i) == 1)
302           x_cdims(i-1) = 0;
303         if (y_dims(i) == 1)
304           y_cdims(i-1) = 0;
305       }
306 
307     // Check if broadcasting along the first dimension is needed
308     bool broadcast_first = (x_dims(0) == 1 && y_dims(0) != 1)
309         || (x_dims(0) != 1 && y_dims(0) == 1);
310 
311     // Accumulators
312     mpfr_t accu, product;
313     mpfr_init2 (accu, BINARY64_ACCU_PRECISION);
314     mpfr_init2 (product, 2 * BINARY64_PRECISION + 1);
315 
316     // Loop over all elements in the result
317     octave_idx_type x_idx;
318     octave_idx_type y_idx;
319     OCTAVE_LOCAL_BUFFER_INIT (octave_idx_type, idx, dimensions, 0);
320 
321     octave_idx_type n = result_and_error.first.numel ();
322     octave_idx_type m;
323     if (x_dims(dim - 1) != 1)
324         m = x_dims(dim - 1);
325     else
326         m = y_dims(dim - 1);
327 
328     for (octave_idx_type i = 0; i < n; i ++)
329       {
330         // Take broadcasting into account
331         x_idx = x_cdims.cum_compute_index (idx);
332         y_idx = y_cdims.cum_compute_index (idx);
333 
334         // Broadcasting along the first dimension needs to be handled
335         // separately
336         if (broadcast_first)
337           {
338             if (x_dims(0) == 1)
339               x_idx -= idx[0];
340             else
341               y_idx -= idx[0];
342           }
343 
344         mpfr_set_zero (accu, 0);
345         // Compute result for element i
346         for (octave_idx_type j = 0; j < m; j ++)
347           {
348             mpfr_set_d (product, array_x.elem (x_idx + x_idx_increment*j),
349                         MPFR_RNDZ);
350             mpfr_mul_d (product, product,
351                         array_y.elem (y_idx + y_idx_increment*j), MPFR_RNDZ);
352 
353             int exact = mpfr_add (accu, accu, product, MPFR_RNDZ);
354             if (exact != 0)
355               error ("mpfr_vector_dot_d: Failed to compute exact dot product");
356             if (mpfr_nan_p (accu))
357               // Short-Circtuit if one addend is NAN or if -INF + INF
358               break;
359           }
360         double result;
361         double error;
362         if (mpfr_nan_p (accu) != 0)
363           {
364             result = NAN;
365             error = NAN;
366           }
367         else
368           {
369             if (mpfr_cmp_d (accu, 0.0) == 0)
370               {
371                 // exact zero
372                 if (rnd == MPFR_RNDD)
373                   result = -0.0;
374                 else
375                   result = +0.0;
376                 error = 0.0;
377               }
378             else
379               {
380                 result = mpfr_get_d (accu, rnd);
381                 if (compute_error)
382                   {
383                     mpfr_sub_d (accu, accu, result, MPFR_RNDA);
384                     error = mpfr_get_d (accu, MPFR_RNDA);
385                   }
386                 else
387                   error = 0.0;
388               }
389           }
390         result_and_error.first(i) = result;
391         result_and_error.second(i) = error;
392 
393         result_dims.increment_index (idx);
394     }
395     mpfr_clear (accu);
396     mpfr_clear (product);
397 
398     return result_and_error;
399 }
400 
401 DEFUN_DLD (mpfr_vector_dot_d, args, nargout,
402   "-*- texinfo -*-\n"
403   "@documentencoding UTF-8\n"
404   "@deftypefun  {[@var{L}, @var{U}] =} mpfr_vector_dot_d (@var{XL}, @var{YL}, @var{XU}, @var{YU}, @var{DIM})\n"
405   "@deftypefunx {[@var{D}, @var{E}] =} mpfr_vector_dot_d (@var{R}, @var{X}, @var{Y}, @var{dim})\n"
406   "\n"
407   "Compute the dot product of arrays of binary 64 numbers along dimension"
408   "@var{DIM} with correctly rounded result."
409   "\n\n"
410   "Syntax 1: Compute the lower and upper boundary of the dot product of "
411   "interval arrays [@var{XL}, @var{XU}] and [@var{YL}, @var{YU}] with "
412   "tightest accuracy."
413   "\n\n"
414   "Syntax 2: Compute the dot product @var{D} of two binary64 arrays with "
415   "correctly rounded result and rounding direction @var{R} (@option{0}: "
416   "towards zero, @option{0.5}: towards nearest and ties to even, "
417   "@option{+inf}: towards positive infinity, @option{-inf}: towards negative "
418   "infinity).  Output parameter @var{E} yields an approximation of the error, "
419   "that is the difference between the exact dot product and @var{D} as a "
420   "second binary64 number, rounded towards zero."
421   "\n\n"
422   "The result is guaranteed to be tight / correctly rounded.  That is, the "
423   "dot product is evaluated with (virtually) infinite precision and the exact "
424   "result is approximated with a binary64 number using the desired rounding "
425   "direction."
426   "\n\n"
427   "For syntax 2 only: If one element of a dot product is NaN, infinities of "
428   "both signs or a product of zero and (positive or negative) infinity are "
429   "encountered, the result will be NaN.  An @emph{exact} zero is returned as "
430   "+0 in all rounding directions, except for rounding towards negative "
431   "infinity, where -0 is returned."
432   "\n\n"
433   "@example\n"
434   "@group\n"
435   "[l, u] = mpfr_vector_dot_d (-1, -1, 2, 3, 1)\n"
436   "  @result{}\n"
437   "    l = -3\n"
438   "    u = 6\n"
439   "@end group\n"
440   "@end example\n"
441   "@seealso{dot}\n"
442   "@end deftypefun"
443   )
444 {
445   // Check call syntax
446   int nargin = args.length ();
447   if (nargin < 4 || nargin > 5)
448     {
449       print_usage ();
450       return octave_value_list ();
451     }
452 
453   octave_value_list result;
454   switch (nargin)
455     {
456       case 5: // Interval version
457         {
458           NDArray array_xl = args (0).array_value ();
459           NDArray array_yl = args (1).array_value ();
460           NDArray array_xu = args (2).array_value ();
461           NDArray array_yu = args (3).array_value ();
462           octave_idx_type dim = args (4).scalar_value ();
463           if (error_state)
464             return octave_value_list ();
465 
466           std::pair <NDArray, NDArray> result_d =
467               interval_vector_dot (array_xl, array_yl, array_xu, array_yu, dim);
468           result (0) = result_d.first;
469           result (1) = result_d.second;
470           break;
471         }
472       case 4: // Non-interval version
473         {
474           const mpfr_rnd_t rnd = parse_rounding_mode (args (0).scalar_value());
475           const NDArray array_x = args (1).array_value ();
476           const NDArray array_y = args (2).array_value ();
477           const octave_idx_type dim = args (3).scalar_value ();
478           if (error_state)
479             return octave_value_list ();
480 
481           std::pair <NDArray, NDArray> result_and_error
482               = vector_dot (rnd, array_x, array_y, dim, nargout >= 2);
483           result (0) = result_and_error.first;
484           result (1) = result_and_error.second;
485           break;
486         }
487     }
488 
489   return result;
490 }
491 
492 /*
493 %!test;
494 %!  [l, u] = mpfr_vector_dot_d (-1, -1, 2, 3, 1);
495 %!  assert (l, -3);
496 %!  assert (u, 6);
497 %!test;
498 %!  x = [realmax, realmax, -realmax, -realmax, 1, eps/2];
499 %!  y = ones (size (x));
500 %!  [l, u] = mpfr_vector_dot_d (x, y, x, y, 2);
501 %!  d = mpfr_vector_dot_d (0.5, x, y, 2);
502 %!  assert (l, 1);
503 %!  assert (u, 1 + eps);
504 %!  assert (ismember (d, infsup (l, u)));
505 %!test;
506 %!  [l, u] = mpfr_vector_dot_d (0, 0, inf, inf, 1);
507 %!  d = mpfr_vector_dot_d (0.5, 0, inf, 1);
508 %!  assert (l, 0);
509 %!  assert (u, inf);
510 %!  assert (isequaln (d, NaN));
511 %!test;
512 %!  x = reshape (1:24, 2, 3, 4);
513 %!  y = 2.*ones (2, 3, 4);
514 %!  [l u] = mpfr_vector_dot_d (x, y, x, y, 3);
515 %!  d = mpfr_vector_dot_d (0.5, x, y, 3);
516 %!  assert (l, [80, 96, 112; 88, 104, 120]);
517 %!  assert (u, [80, 96, 112; 88, 104, 120]);
518 %!  assert (d, [80, 96, 112; 88, 104, 120]);
519 
520 %!shared testdata
521 %! testdata = load (file_in_loadpath ("test/itl.mat"));
522 
523 %!test
524 %! # Scalar evaluation
525 %! testcases = testdata.NoSignal.double.dot_nearest;
526 %! for testcase = [testcases]'
527 %!   assert (isequaln (...
528 %!     mpfr_vector_dot_d (0.5, testcase.in{1}, testcase.in{2}, 2), ...
529 %!     testcase.out));
530 %! endfor
531 
532 %!test
533 %! # Vector evaluation
534 %! testcases = testdata.NoSignal.double.dot_nearest;
535 %! in1 = vertcat (testcases.in)(:, 1);
536 %! in1 = cell2mat (cellfun ("postpad", in1, {(max (cellfun ("numel", in1)))}, "UniformOutput", false));
537 %! in2 = vertcat (testcases.in)(:, 2);
538 %! in2 = cell2mat (cellfun ("postpad", in2, {(max (cellfun ("numel", in2)))}, "UniformOutput", false));
539 %! out = vertcat (testcases.out);
540 %! assert (isequaln (mpfr_vector_dot_d (0.5, in1, in2, 2), out));
541 
542 %!test
543 %! # Scalar evaluation
544 %! testcases = testdata.NoSignal.double.sum_sqr_nearest;
545 %! for testcase = [testcases]'
546 %!   assert (isequaln (...
547 %!     mpfr_vector_dot_d (0.5, testcase.in{1}, testcase.in{1}, 2), ...
548 %!     testcase.out));
549 %! endfor
550 
551 %!test
552 %! # Vector evaluation
553 %! testcases = testdata.NoSignal.double.sum_sqr_nearest;
554 %! in1 = vertcat (testcases.in);
555 %! in1 = cell2mat (cellfun ("postpad", in1, {(max (cellfun ("numel", in1)))}, "UniformOutput", false));
556 %! out = vertcat (testcases.out);
557 %! assert (isequaln (mpfr_vector_dot_d (0.5, in1, in1, 2), out));
558 
559 */
560