1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Precision doubling arithmetic transform
16
17Following the approach of Dekker 1971
18(http://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf).
19"""
20import decimal
21from functools import wraps
22import operator
23from typing import Any, Callable, Dict, Sequence
24
25import numpy as np
26
27from jax.tree_util import tree_flatten, tree_unflatten
28from jax.api_util import flatten_fun_nokwargs
29from jax import ad_util, core, lax, xla
30from jax._src.lax import lax as lax_internal
31from jax._src.util import unzip2, wrap_name
32import jax.numpy as jnp
33import jax.linear_util as lu
34
# Sentinel type: instances carry no state and print as "_zeros".
class _Zeros:
  def __repr__(self):
    return "_zeros"
# Module-level instance of the sentinel.
# NOTE(review): `_zero` is not referenced anywhere in the visible part of this
# module; presumably a placeholder for a symbolically-zero value -- confirm
# before relying on or removing it.
_zero = _Zeros()
39
class DoublingTracer(core.Tracer):
  """Tracer that carries a value as a ``(head, tail)`` pair.

  ``head`` is the leading component and ``tail`` the correction term. A
  ``tail`` of ``None`` marks a value that has no tail component.
  """

  def __init__(self, trace, head, tail):
    # TODO(vanderplas): check head/tail have matching shapes & dtypes
    self._trace = trace
    self.head, self.tail = head, tail

  @property
  def aval(self):
    """Shaped abstract value of ``head``; ``tail`` is not consulted."""
    head_aval = core.get_aval(self.head)
    return core.raise_to_shaped(head_aval)

  def full_lower(self):
    """Return the lowered ``head`` when there is no tail, else ``self``."""
    return self if self.tail is not None else core.full_lower(self.head)
55
56
class DoublingTrace(core.Trace):
  """Trace that evaluates primitives on ``(head, tail)`` pairs."""

  def pure(self, val: Any):
    """Box ``val`` as a tracer whose tail is all zeros."""
    zero_tail = jnp.zeros(jnp.shape(val), jnp.result_type(val))
    return DoublingTracer(self, val, zero_tail)

  def lift(self, val: core.Tracer):
    """Box the tracer ``val`` as a doubling tracer whose tail is all zeros."""
    zero_tail = jnp.zeros(jnp.shape(val), jnp.result_type(val))
    return DoublingTracer(self, val, zero_tail)

  def sublift(self, val: DoublingTracer):
    """Re-box an existing doubling tracer under this trace."""
    return DoublingTracer(self, val.head, val.tail)

  def process_primitive(self, primitive, tracers, params):
    """Apply the rule registered for ``primitive`` in ``doubling_rules``.

    The rule receives one ``(head, tail)`` pair per operand plus ``params`` as
    keyword arguments; its result is unpacked into a new tracer.

    Raises:
      NotImplementedError: if no rule is registered for ``primitive``.
    """
    rule = doubling_rules.get(primitive)
    if rule is None:
      raise NotImplementedError(f"primitive={primitive}")
    operand_pairs = [(t.head, t.tail) for t in tracers]
    out = rule(*operand_pairs, **params)
    # TODO: handle primitive.multiple_results
    return DoublingTracer(self, *out)

  def process_call(self, call_primitive, f, tracers, params):
    """Trace a call primitive by binding it on the flattened heads and tails."""
    assert call_primitive.multiple_results
    heads, tails = unzip2((t.head, t.tail) for t in tracers)
    # tree_flatten treats None as an empty subtree, so only the tails that are
    # actually present become call arguments; `tails_tree` remembers the gaps.
    present_tails, tails_tree = tree_flatten(tails)
    num_heads = len(heads)
    f_double, out_tree_def = screen_nones(
        doubling_subtrace(f, self.main), num_heads, tails_tree)
    call_name = wrap_name(params.get('name', f.__name__), 'doubledouble')
    donated = (False,) * (num_heads + len(present_tails))
    new_params = dict(params, name=call_name, donated_invars=donated)
    result = call_primitive.bind(f_double, *heads, *present_tails, **new_params)
    heads_out, tails_out = tree_unflatten(out_tree_def(), result)
    return [DoublingTracer(self, h, t) for h, t in zip(heads_out, tails_out)]
87
88
@lu.transformation
def doubling_subtrace(main, heads, tails):
  """Run the wrapped function on doubling tracers built from heads/tails.

  Inputs whose tail is ``None`` are passed through unboxed. The outputs are
  raised to this trace and yielded as ``(heads, tails)``.
  """
  trace = DoublingTrace(main, core.cur_sublevel())
  in_tracers = []
  for head, tail in zip(heads, tails):
    boxed = head if tail is None else DoublingTracer(trace, head, tail)
    in_tracers.append(boxed)
  ans = yield in_tracers, {}
  out_pairs = []
  for out in ans:
    raised = trace.full_raise(out)
    out_pairs.append((raised.head, raised.tail))
  yield unzip2(out_pairs)
98
99
@lu.transformation_with_aux
def screen_nones(num_heads, in_tree_def, *heads_and_tails):
  """Rebuild the tails tree on the way in and flatten the outputs on the way out.

  The first ``num_heads`` positional values are the heads; the remainder are
  the flattened tails, which are restored with ``in_tree_def``. The output
  ``(heads, tails)`` pair is flattened and its tree definition is yielded as
  the auxiliary value.
  """
  heads = heads_and_tails[:num_heads]
  flat_tails = heads_and_tails[num_heads:]
  tails = tree_unflatten(in_tree_def, flat_tails)
  head_out, tail_out = yield (heads, tails), {}
  out_flat, out_tree_def = tree_flatten((head_out, tail_out))
  yield out_flat, out_tree_def
108
109
@lu.transformation
def doubling_transform(*args):
  """Evaluate the wrapped function under a fresh ``DoublingTrace``.

  Each argument is a ``(head, tail)`` pair. The result is a list of
  ``(head, tail)`` pairs when the function returns a sequence, otherwise a
  single pair.
  """
  with core.new_main(DoublingTrace) as main:
    trace = DoublingTrace(main, core.cur_sublevel())
    in_tracers = []
    for head, tail in args:
      in_tracers.append(DoublingTracer(trace, head, tail))
    outputs = yield in_tracers, {}
    if isinstance(outputs, Sequence):
      result = []
      for out in outputs:
        raised = trace.full_raise(out)
        result.append((raised.head, raised.tail))
    else:
      raised = trace.full_raise(outputs)
      result = (raised.head, raised.tail)
  yield result
123
124
def doubledouble(f):
  """Wrap ``f`` so it is evaluated with the precision-doubling transform.

  Every leaf of the positional arguments is paired with an all-zero tail, ``f``
  is traced through ``doubling_transform``, and each output ``(head, tail)``
  pair is collapsed back into a single value.

  Args:
    f: function taking positional arguments only (the wrapper accepts no
      keyword arguments).

  Returns:
    A function with the same positional signature as ``f`` whose outputs have
    the same tree structure as those of ``f``.
  """
  @wraps(f)
  def wrapped(*args):
    args_flat, in_tree = tree_flatten(args)
    f_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    arg_pairs = [(x, jnp.zeros_like(x)) for x in args_flat]
    out_pairs_flat = doubling_transform(f_flat).call_wrapped(*arg_pairs)
    # Some rules (comparisons, conversions to non-floating dtypes) return a
    # tail of None; such outputs are just their head, and `head + None` would
    # raise a TypeError.
    out_flat = [head if tail is None else head + tail
                for head, tail in out_pairs_flat]
    return tree_unflatten(out_tree(), out_flat)
  return wrapped
136
137
# Registry mapping each supported primitive to its doubling rule. A rule is
# called with one (head, tail) pair per operand plus the primitive's params as
# keyword arguments (see DoublingTrace.process_primitive) and returns the
# output (head, tail) pair.
doubling_rules: Dict[core.Primitive, Callable] = {}
139
140def _mul_const(dtype):
141  _nmant = jnp.finfo(dtype).nmant
142  return jnp.array((2 << (_nmant - _nmant // 2)) + 1, dtype=dtype)
143
144def _normalize(x, y):
145  z = jnp.where(jnp.isinf(x), x, x + y)
146  zz = jnp.where(
147    lax.abs(x) > lax.abs(y),
148    x - z + y,
149    y - z + x,
150  )
151  return z, jnp.where(jnp.isinf(z), 0, zz)
152
153def _abs2(x):
154  x, xx = x
155  sign = jnp.where(lax.sign(x) == lax.sign(xx), 1, -1)
156  return (lax.abs(x), sign * lax.abs(xx))
# Use _abs2 as the doubling rule for lax.abs_p.
doubling_rules[lax.abs_p] = _abs2
158
159def _neg2(x):
160  return (-x[0], -x[1])
# Use _neg2 as the doubling rule for lax.neg_p.
doubling_rules[lax.neg_p] = _neg2
162
163def _add2(x, y):
164  (x, xx), (y, yy) = x, y
165  r = x + y
166  s = jnp.where(
167    lax.abs(x) > lax.abs(y),
168    x - r + y + yy + xx,
169    y - r + x + xx + yy,
170  )
171  z = r + s
172  zz = r - z + s
173  return (z, zz)
# Use _add2 as the doubling rule for lax.add_p.
doubling_rules[lax.add_p] = _add2
175
176def _sub2(x, y):
177  (x, xx), (y, yy) = x, y
178  r = x - y
179  s = jnp.where(
180    lax.abs(x) > lax.abs(y),
181    x - r - y - yy + xx,
182    -y - r + x + xx - yy,
183  )
184  z = r + s
185  zz = r - z + s
186  return (z, zz)
# Use _sub2 as the doubling rule for lax.sub_p.
doubling_rules[lax.sub_p] = _sub2
188
def _mul12(x, y):
  """Multiply two single values and return the product as a ``(z, zz)`` pair.

  Each operand is split into a high and a low part using the constant from
  ``_mul_const``; the partial products are then combined into a leading value
  ``z`` and a remainder ``zz``.
  """
  dtype = jnp.result_type(x, y)
  K = _mul_const(dtype)

  def split(v):
    # High part: v with its low-order bits rounded away; low part: the rest.
    p = v * K
    high = v - p + p
    return high, v - high

  hx, tx = split(x)
  hy, ty = split(y)
  p = hx * hy
  q = hx * ty + tx * hy
  z = p + q
  return z, p - z + q + tx * ty
203
def _mul2(x, y):
  """Multiply two ``(head, tail)`` pairs and return the product as a pair.

  The heads are multiplied with ``_mul12``; the head-times-tail cross terms
  are added to its remainder (the tail-times-tail term is not included), and
  the result is split into head and remainder.
  """
  (x_hi, x_lo), (y_hi, y_lo) = x, y
  c, c_err = _mul12(x_hi, y_hi)
  cross_terms = x_hi * y_lo + x_lo * y_hi
  cc = cross_terms + c_err
  z = c + cc
  return (z, c - z + cc)
doubling_rules[lax.mul_p] = _mul2
212
def _div2(x, y):
  """Divide the pair ``x`` by the pair ``y`` and return the quotient as a pair.

  The quotient of the heads is refined by one correction step: the residual
  ``x - c * y`` is evaluated with ``_mul12`` and divided by the head of ``y``.
  """
  (x_hi, x_lo), (y_hi, y_lo) = x, y
  c = x_hi / y_hi
  u, uu = _mul12(c, y_hi)
  residual = x_hi - u - uu + x_lo - c * y_lo
  cc = residual / y_hi
  z = c + cc
  return z, c - z + cc
doubling_rules[lax.div_p] = _div2
222
def _sqrt2(x):
  """Square root of a ``(head, tail)`` pair, returned as a pair.

  The square root of the head is refined by one correction step based on the
  residual ``x - c * c``, which is evaluated with ``_mul12``.

  Args:
    x: ``(head, tail)`` pair.

  Returns:
    ``(y, yy)`` pair for the square root. Where the square root of the head is
    exactly zero the correction is taken to be zero, so the square root of
    zero is zero rather than NaN.
  """
  x, xx = x
  c = lax.sqrt(x)
  u, uu = _mul12(c, c)
  # With c == 0 the correction would evaluate 0 * 0.5 / 0 and poison the
  # result with NaN, so skip the refinement there.
  cc = jnp.where(c == 0, 0, (x - u - uu + xx) * 0.5 / c)
  y = c + cc
  yy = c - y + cc
  return y, yy
doubling_rules[lax.sqrt_p] = _sqrt2
232
233
def _def_inequality(prim, op):
  """Register a comparison rule for ``prim`` built from the operator ``op``.

  The rule subtracts the operand pairs with ``_sub2``, collapses the
  difference to a single value and compares it against zero with ``op``.
  The result carries no tail (``None``).
  """
  def transformed(x, y):
    diff_head, diff_tail = _sub2(x, y)
    difference = diff_head + diff_tail
    return op(difference, 0), None
  doubling_rules[prim] = transformed

# Ordering and (in)equality primitives all go through the same difference test.
_def_inequality(lax.gt_p, operator.gt)
_def_inequality(lax.ge_p, operator.ge)
_def_inequality(lax.lt_p, operator.lt)
_def_inequality(lax.le_p, operator.le)
_def_inequality(lax.eq_p, operator.eq)
_def_inequality(lax.ne_p, operator.ne)
246
def _convert_element_type(operand, new_dtype):
  """Doubling rule for ``convert_element_type``.

  Head and tail (when present) are each converted to ``new_dtype``. For a
  floating ``new_dtype`` a missing tail is replaced by zeros; for any other
  dtype the converted tail is folded into the head and the result has no tail.
  """
  head, tail = operand

  def convert(value):
    return lax.convert_element_type_p.bind(value, new_dtype=new_dtype)

  head = convert(head)
  tail = None if tail is None else convert(tail)
  if jnp.issubdtype(new_dtype, jnp.floating):
    return (head, jnp.zeros_like(head) if tail is None else tail)
  if tail is None:
    return (head, None)
  return (head + tail, None)
doubling_rules[lax.convert_element_type_p] = _convert_element_type
260
def _add_jaxvals(xs, ys):
  """Doubling rule for ``add_jaxvals``: add the two pairs with ``_add2``."""
  # Earlier variant, kept for reference; it dispatched on the type of the head:
  # return ad_util.jaxval_adders[type(xs[0])](xs, ys)
  return _add2(xs, ys)
# Use _add_jaxvals as the doubling rule for ad_util.add_jaxvals_p.
doubling_rules[ad_util.add_jaxvals_p] = _add_jaxvals
265
def _def_passthrough(prim, argnums=(0,), shared_argnums=()):
  """Register a rule that applies ``prim`` separately to heads and to tails.

  Args:
    prim: primitive to register in ``doubling_rules``.
    argnums: positions of operands that are split: the head is used when
      computing the output head and the tail when computing the output tail.
    shared_argnums: positions of operands whose head is used for both
      computations (for example a boolean predicate, whose tail carries no
      information). Takes precedence over ``argnums``.

  Operands listed in neither are forwarded unchanged, and the primitive's
  params are passed through as keyword arguments to both binds.
  """
  def select_component(args, component):
    picked = []
    for i, arg in enumerate(args):
      if i in shared_argnums:
        picked.append(arg[0])
      elif i in argnums:
        picked.append(arg[component])
      else:
        picked.append(arg)
    return picked

  def transformed(*args, **kwargs):
    return (
      prim.bind(*select_component(args, 0), **kwargs),
      prim.bind(*select_component(args, 1), **kwargs)
    )
  doubling_rules[prim] = transformed

# The predicate of select must be its head for both outputs: using the
# predicate's tail (all False for a boolean) would always pick the tail of the
# on-false operand, regardless of which head was selected.
_def_passthrough(lax.select_p, (1, 2), shared_argnums=(0,))
_def_passthrough(lax.broadcast_in_dim_p)
_def_passthrough(xla.device_put_p)
try:
  _def_passthrough(lax_internal.tie_in_p, (0, 1))
except AttributeError:
  # lax_internal has no tie_in_p attribute here; there is nothing to register.
  pass
281
282
class _DoubleDouble:
  """DoubleDouble class with overloaded operators.

  A value is stored as a normalized ``(head, tail)`` pair of arrays sharing
  one dtype. Arithmetic operators return new ``_DoubleDouble`` instances in
  the dtype of the left-hand instance (``self``); comparison operators return
  boolean arrays.
  """
  __slots__ = ["head", "tail"]

  def __init__(self, val, dtype=None):
    """Build a double-double value.

    Args:
      val: one of
        * a ``(head, tail)`` tuple,
        * a ``str`` holding a decimal literal (split exactly via ``decimal``),
        * an ``int`` (split exactly into head and remainder),
        * another ``_DoubleDouble``,
        * anything else accepted by ``jnp.zeros_like`` (tail starts at zero).
      dtype: dtype of head and tail. When ``None``, strings and ints use
        float64 and other inputs use ``jnp.result_type(head, tail)``.
    """
    # Test `dtype is None` explicitly rather than `dtype or default`: NumPy
    # dtype objects without fields have length 0 and can evaluate as falsy,
    # which would silently discard an explicitly requested dtype.
    if isinstance(val, tuple):
      head, tail = val
    elif isinstance(val, str):
      dtype = jnp.dtype('float64' if dtype is None else dtype).type
      val = decimal.Decimal(val)
      head = dtype(val)
      tail = 0 if np.isinf(head) else dtype(val - decimal.Decimal(float(head)))
    elif isinstance(val, int):
      dtype = jnp.dtype('float64' if dtype is None else dtype).type
      head = dtype(val)
      tail = 0 if np.isinf(head) else dtype(val - int(head))
    elif isinstance(val, _DoubleDouble):
      head, tail = val.head, val.tail
    else:
      head, tail = val, jnp.zeros_like(val)
    if dtype is None:
      dtype = jnp.result_type(head, tail)
    head = jnp.asarray(head, dtype=dtype)
    tail = jnp.asarray(tail, dtype=dtype)
    self.head, self.tail = _normalize(head, tail)

  def normalize(self):
    """Return a normalized copy of self."""
    return self._wrap(_normalize(self.head, self.tail))

  @property
  def dtype(self):
    """Dtype of the head array."""
    return self.head.dtype

  def to_array(self, dtype=None):
    """Collapse the pair into a single array.

    Args:
      dtype: optional dtype both components are cast to before being summed.

    Returns:
      ``head + tail``, with the tail ignored wherever the head is infinite.
    """
    head, tail = self._tup
    if dtype is not None:
      head = head.astype(dtype)
      tail = tail.astype(dtype)
    return head + jnp.where(jnp.isinf(head), 0, tail)

  def __repr__(self):
    return f"{self.__class__.__name__}({self.head}, {self.tail})"

  @property
  def _tup(self):
    """The value as a ``(head, tail)`` tuple."""
    return self.head, self.tail

  def _wrap(self, other):
    """Convert ``other`` to an instance of this class in this value's dtype."""
    return self.__class__(other, dtype=self.dtype)

  def __abs__(self):
    return self._wrap(_abs2(self._tup))

  def __neg__(self):
    return self._wrap(_neg2(self._tup))

  def __add__(self, other):
    return self._wrap(_add2(self._tup, self._wrap(other)._tup))

  def __sub__(self, other):
    return self._wrap(_sub2(self._tup, self._wrap(other)._tup))

  def __mul__(self, other):
    return self._wrap(_mul2(self._tup, self._wrap(other)._tup))

  def __truediv__(self, other):
    return self._wrap(_div2(self._tup, self._wrap(other)._tup))

  def __radd__(self, other):
    return self._wrap(_add2(self._wrap(other)._tup, self._tup))

  def __rsub__(self, other):
    return self._wrap(_sub2(self._wrap(other)._tup, self._tup))

  def __rmul__(self, other):
    return self._wrap(_mul2(self._wrap(other)._tup, self._tup))

  def __rtruediv__(self, other):
    return self._wrap(_div2(self._wrap(other)._tup, self._tup))

  # Comparisons test the sign of (self - other) after collapsing the
  # difference to a single array, and return boolean arrays.
  def __lt__(self, other):
    return (self - other).to_array() < 0

  def __le__(self, other):
    return (self - other).to_array() <= 0

  def __gt__(self, other):
    return (self - other).to_array() > 0

  def __ge__(self, other):
    return (self - other).to_array() >= 0

  def __eq__(self, other):
    return (self - other).to_array() == 0

  def __ne__(self, other):
    return (self - other).to_array() != 0
380