# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Precision doubling arithmetic transform

Following the approach of Dekker 1971
(http://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf).

Each floating-point value is carried as an unevaluated sum ``head + tail`` of
two numbers of the same dtype. The :func:`doubledouble` transform traces a
function with :class:`DoublingTracer` values and evaluates every primitive
with the rule registered for it in ``doubling_rules``; :class:`_DoubleDouble`
exposes the same arithmetic through overloaded Python operators.
"""
import decimal
from functools import wraps
import operator
from typing import Any, Callable, Dict, Sequence

import numpy as np

from jax.tree_util import tree_flatten, tree_unflatten
from jax.api_util import flatten_fun_nokwargs
from jax import ad_util, core, lax, xla
from jax._src.lax import lax as lax_internal
from jax._src.util import unzip2, wrap_name
import jax.numpy as jnp
import jax.linear_util as lu

class _Zeros:
  def __repr__(self):
    return "_zeros"
_zero = _Zeros()

class DoublingTracer(core.Tracer):
  """Tracer carrying one traced value as the pair ``(head, tail)``.

  ``tail`` may be ``None`` for values with no meaningful low-order part (the
  rules below produce this for comparison results and for conversions to
  non-floating dtypes); such tracers lower to their ``head``.
  """

  def __init__(self, trace, head, tail):
    self._trace = trace
    # TODO(vanderplas): check head/tail have matching shapes & dtypes
    self.head = head
    self.tail = tail

  @property
  def aval(self):
    # The abstract value (shape/dtype) is taken from the head alone.
    return core.raise_to_shaped(core.get_aval(self.head))

  def full_lower(self):
    # A value without a tail is just its head; drop the wrapper.
    if self.tail is None:
      return core.full_lower(self.head)
    else:
      return self


class DoublingTrace(core.Trace):
  """Trace that evaluates primitives on ``(head, tail)`` pairs."""

  def pure(self, val: Any):
    # Constants enter the trace with an all-zero tail of matching shape/dtype.
    return DoublingTracer(self, val, jnp.zeros(jnp.shape(val), jnp.result_type(val)))

  def lift(self, val: core.Tracer):
    # Tracers from lower levels likewise enter with an all-zero tail.
    return DoublingTracer(self, val, jnp.zeros(jnp.shape(val), jnp.result_type(val)))

  def sublift(self, val: DoublingTracer):
    return DoublingTracer(self, val.head, val.tail)

  def process_primitive(self, primitive, tracers, params):
    """Apply the registered doubling rule for ``primitive``.

    Raises:
      NotImplementedError: if no rule is registered in ``doubling_rules``.
    """
    func = doubling_rules.get(primitive, None)
    if func is None:
      raise NotImplementedError(f"primitive={primitive}")
    out = func(*((t.head, t.tail) for t in tracers), **params)
    # TODO: handle primitive.multiple_results
    return DoublingTracer(self, *out)

  def process_call(self, call_primitive, f, tracers, params):
    """Bind ``call_primitive`` on a doubled version of ``f``.

    Heads and the non-``None`` tails are passed to the call as one flat
    argument list; ``screen_nones`` rebuilds the pairs inside the call.
    """
    assert call_primitive.multiple_results
    heads, tails = unzip2((t.head, t.tail) for t in tracers)
    # tree_flatten drops None tails, so only real tails become call arguments.
    nonzero_tails, in_tree_def = tree_flatten(tails)
    f_double, out_tree_def = screen_nones(doubling_subtrace(f, self.main),
                                          len(heads), in_tree_def)
    name = params.get('name', f.__name__)
    # The flat argument list changed, so the original donation mask no longer
    # applies; disable donation for every argument.
    new_params = dict(params, name=wrap_name(name, 'doubledouble'),
                      donated_invars=(False,) * (len(heads) + len(nonzero_tails)))
    result = call_primitive.bind(f_double, *heads, *nonzero_tails, **new_params)
    heads_out, tails_out = tree_unflatten(out_tree_def(), result)
    return [DoublingTracer(self, h, t) for h, t in zip(heads_out, tails_out)]


@lu.transformation
def doubling_subtrace(main, heads, tails):
  """Run the wrapped function on pairs at a new sublevel of ``main``.

  Inputs whose tail is ``None`` are passed through as plain heads. Yields the
  outputs as ``(heads, tails)``.
  """
  trace = DoublingTrace(main, core.cur_sublevel())
  in_tracers = [DoublingTracer(trace, h, t) if t is not None else h
                for h, t in zip(heads, tails)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2([(out_tracer.head, out_tracer.tail)
                for out_tracer in out_tracers])


@lu.transformation_with_aux
def screen_nones(num_heads, in_tree_def, *heads_and_tails):
  """Adapt a ``(heads, tails)`` function to a flat argument/result list.

  The first ``num_heads`` positional arguments are heads; the rest are the
  flattened tails, restored (including ``None`` entries) with
  ``in_tree_def``. The outputs are flattened and their treedef is returned as
  auxiliary data.
  """
  new_heads = heads_and_tails[:num_heads]
  new_tails = heads_and_tails[num_heads:]
  new_tails = tree_unflatten(in_tree_def, new_tails)
  head_out, tail_out = yield (new_heads, new_tails), {}
  out_flat, tree_def = tree_flatten((head_out, tail_out))
  yield out_flat, tree_def


@lu.transformation
def doubling_transform(*args):
  """Trace the wrapped function with ``(head, tail)`` pairs as inputs.

  Yields a list of ``(head, tail)`` pairs when the function returns a
  sequence, otherwise a single pair.
  """
  with core.new_main(DoublingTrace) as main:
    trace = DoublingTrace(main, core.cur_sublevel())
    in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
    outputs = yield in_tracers, {}
    if isinstance(outputs, Sequence):
      out_tracers = map(trace.full_raise, outputs)
      result = [(x.head, x.tail) for x in out_tracers]
    else:
      out_tracer = trace.full_raise(outputs)
      result = (out_tracer.head, out_tracer.tail)
  yield result


def doubledouble(f):
  """Return a version of ``f`` that evaluates in double-double precision.

  Every (pytree-flattened) positional argument is paired with a zero tail,
  ``f`` is traced with :class:`DoublingTrace`, and each output pair is
  rounded back to its dtype as ``head + tail``. Only primitives registered in
  ``doubling_rules`` are supported; others raise ``NotImplementedError``.
  """
  @wraps(f)
  def wrapped(*args):
    args_flat, in_tree = tree_flatten(args)
    f_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(f), in_tree)
    arg_pairs = [(x, jnp.zeros_like(x)) for x in args_flat]
    out_pairs_flat = doubling_transform(f_flat).call_wrapped(*arg_pairs)
    out_flat = [head + tail for head, tail in out_pairs_flat]
    out = tree_unflatten(out_tree(), out_flat)
    return out
  return wrapped


# Maps each supported primitive to a rule that takes one (head, tail) pair per
# operand plus the primitive's params, and returns one (head, tail) pair.
doubling_rules: Dict[core.Primitive, Callable] = {}

def _mul_const(dtype):
  """Return the splitting constant ``2**(nmant - nmant // 2 + 1) + 1``.

  Used by ``_mul12`` to split a float into high and low parts
  (``2**27 + 1`` for float64).
  """
  _nmant = jnp.finfo(dtype).nmant
  return jnp.array((2 << (_nmant - _nmant // 2)) + 1, dtype=dtype)

def _normalize(x, y):
  """Renormalize ``x + y`` into ``(z, zz)`` with ``z = fl(x + y)``.

  ``zz`` is the rounding error of that sum; infinite heads get a zero tail.
  """
  z = jnp.where(jnp.isinf(x), x, x + y)
  zz = jnp.where(
    lax.abs(x) > lax.abs(y),
    x - z + y,
    y - z + x,
  )
  return z, jnp.where(jnp.isinf(z), 0, zz)

def _finite_or_naive(naive, z, zz):
  """Fall back to ``(naive, 0)`` wherever ``(z, zz)`` is not finite.

  Dekker's error terms are differences such as ``x - r``; they become
  ``inf - inf = nan`` once an operand or intermediate is infinite (and the
  split ``x * K`` in ``_mul12`` can overflow for huge finite ``x``). Where
  that happens, the plain floating-point result ``naive`` is returned with a
  zero tail, so that e.g. ``1 + inf`` gives ``inf`` rather than ``nan``.
  Finite results are returned unchanged.
  """
  ok = jnp.isfinite(z) & jnp.isfinite(zz)
  return jnp.where(ok, z, naive), jnp.where(ok, zz, jnp.zeros_like(zz))

def _abs2(x):
  """Absolute value of a pair.

  The head becomes ``|x|``. The tail keeps its magnitude: it is positive when
  it had the same sign as the head, and negative otherwise.
  """
  x, xx = x
  sign = jnp.where(lax.sign(x) == lax.sign(xx), 1, -1)
  return (lax.abs(x), sign * lax.abs(xx))
doubling_rules[lax.abs_p] = _abs2

def _neg2(x):
  """Negate both head and tail."""
  return (-x[0], -x[1])
doubling_rules[lax.neg_p] = _neg2

def _add2(x, y):
  """Dekker's add2: double-double sum of two pairs."""
  (x, xx), (y, yy) = x, y
  r = x + y
  # Error of r, computed from the larger-magnitude operand first, plus tails.
  s = jnp.where(
    lax.abs(x) > lax.abs(y),
    x - r + y + yy + xx,
    y - r + x + xx + yy,
  )
  z = r + s
  zz = r - z + s
  return _finite_or_naive(r, z, zz)
doubling_rules[lax.add_p] = _add2

def _sub2(x, y):
  """Dekker's sub2: double-double difference ``x - y`` of two pairs."""
  (x, xx), (y, yy) = x, y
  r = x - y
  s = jnp.where(
    lax.abs(x) > lax.abs(y),
    x - r - y - yy + xx,
    -y - r + x + xx - yy,
  )
  z = r + s
  zz = r - z + s
  return _finite_or_naive(r, z, zz)
doubling_rules[lax.sub_p] = _sub2

def _mul12(x, y):
  """Dekker's mul12: product of two plain floats as a pair ``(z, zz)``.

  Each factor is split into high and low halves with ``_mul_const`` so that
  the partial products can be formed without rounding (for finite inputs away
  from overflow/underflow).
  """
  dtype = jnp.result_type(x, y)
  K = _mul_const(dtype)
  p = x * K
  hx = x - p + p
  tx = x - hx
  p = y * K
  hy = y - p + p
  ty = y - hy
  p = hx * hy
  q = hx * ty + tx * hy
  z = p + q
  zz = p - z + q + tx * ty
  return z, zz

def _mul2(x, y):
  """Dekker's mul2: double-double product of two pairs."""
  (x, xx), (y, yy) = x, y
  c, cc = _mul12(x, y)
  cc = x * yy + xx * y + cc
  z = c + cc
  zz = c - z + cc
  return _finite_or_naive(x * y, z, zz)
doubling_rules[lax.mul_p] = _mul2

def _div2(x, y):
  """Dekker's div2: double-double quotient ``x / y`` of two pairs."""
  (x, xx), (y, yy) = x, y
  c = x / y
  u, uu = _mul12(c, y)
  # One correction step using the residual x - c*y (including tails).
  cc = (x - u - uu + xx - c * yy) / y
  z = c + cc
  zz = c - z + cc
  return _finite_or_naive(c, z, zz)
doubling_rules[lax.div_p] = _div2

def _sqrt2(x):
  """Dekker's sqrt2: double-double square root of a pair.

  The correction step divides by ``sqrt(x)``; at ``x == 0`` (and ``x == inf``)
  it is NaN, so those points fall back to the plain ``lax.sqrt`` result.
  """
  x, xx = x
  c = lax.sqrt(x)
  u, uu = _mul12(c, c)
  cc = (x - u - uu + xx) * 0.5 / c
  y = c + cc
  yy = c - y + cc
  return _finite_or_naive(c, y, yy)
doubling_rules[lax.sqrt_p] = _sqrt2


def _def_inequality(prim, op):
  """Register a comparison rule that compares the sign of ``x - y``.

  The result has a ``None`` tail, so its tracer lowers to the boolean head.
  """
  def transformed(x, y):
    z, zz = _sub2(x, y)
    return op(z + zz, 0), None
  doubling_rules[prim] = transformed

_def_inequality(lax.gt_p, operator.gt)
_def_inequality(lax.ge_p, operator.ge)
_def_inequality(lax.lt_p, operator.lt)
_def_inequality(lax.le_p, operator.le)
_def_inequality(lax.eq_p, operator.eq)
_def_inequality(lax.ne_p, operator.ne)

def _convert_element_type(operand, new_dtype, **params):
  """Convert head and tail to ``new_dtype``.

  Floating targets always get a tail (zeros if there was none). Non-floating
  targets fold the converted tail into the head and return a ``None`` tail.
  Any further primitive params (e.g. ``weak_type`` in newer JAX) are
  forwarded unchanged to both binds.
  """
  head, tail = operand
  head = lax.convert_element_type_p.bind(head, new_dtype=new_dtype, **params)
  if tail is not None:
    tail = lax.convert_element_type_p.bind(tail, new_dtype=new_dtype, **params)
  if jnp.issubdtype(new_dtype, jnp.floating):
    if tail is None:
      tail = jnp.zeros_like(head)
  elif tail is not None:
    head = head + tail
    tail = None
  return (head, tail)
doubling_rules[lax.convert_element_type_p] = _convert_element_type

def _add_jaxvals(xs, ys):
  """Accumulate two pairs (``add_jaxvals``) with double-double addition."""
  return _add2(xs, ys)
doubling_rules[ad_util.add_jaxvals_p] = _add_jaxvals

def _def_passthrough(prim, argnums=(0,)):
  """Register a rule that applies ``prim`` separately to heads and to tails.

  Arguments at positions in ``argnums`` contribute their head to the head
  computation and their tail to the tail computation. Every other argument
  (e.g. the predicate of ``select``) contributes its head to *both*
  computations; its tail carries no meaning there (for a boolean predicate it
  is an all-False placeholder).
  """
  def transformed(*args, **kwargs):
    head_args = [arg[0] for arg in args]
    tail_args = [arg[1] if i in argnums else arg[0]
                 for i, arg in enumerate(args)]
    return (
      prim.bind(*head_args, **kwargs),
      prim.bind(*tail_args, **kwargs)
    )
  doubling_rules[prim] = transformed

# select(pred, on_true, on_false): heads and tails must both be chosen by the
# predicate's head, so only the two branch operands are doubled.
_def_passthrough(lax.select_p, (1, 2))
_def_passthrough(lax.broadcast_in_dim_p)
_def_passthrough(xla.device_put_p)
try:
  _def_passthrough(lax_internal.tie_in_p, (0, 1))
except AttributeError:
  # tie_in_p is absent in JAX versions that removed tie_in.
  pass


class _DoubleDouble:
  """DoubleDouble class with overloaded operators.

  Stores a normalized ``(head, tail)`` pair of arrays and implements
  arithmetic and comparisons with the double-double rules above.
  """
  __slots__ = ["head", "tail"]

  def __init__(self, val, dtype=None):
    """Build a normalized pair from ``val``.

    Accepts a ``(head, tail)`` tuple, a decimal string, a Python int, another
    ``_DoubleDouble``, or an array/scalar (which gets a zero tail). Strings
    and ints default to float64 and keep their rounding error in the tail.
    """
    if isinstance(val, tuple):
      head, tail = val
    elif isinstance(val, str):
      dtype = jnp.dtype(dtype or 'float64').type
      val = decimal.Decimal(val)
      head = dtype(val)
      tail = 0 if np.isinf(head) else dtype(val - decimal.Decimal(float(head)))
    elif isinstance(val, int):
      dtype = jnp.dtype(dtype or 'float64').type
      head = dtype(val)
      tail = 0 if np.isinf(head) else dtype(val - int(head))
    elif isinstance(val, _DoubleDouble):
      head, tail = val.head, val.tail
    else:
      head, tail = val, jnp.zeros_like(val)
    dtype = dtype or jnp.result_type(head, tail)
    head = jnp.asarray(head, dtype=dtype)
    tail = jnp.asarray(tail, dtype=dtype)
    self.head, self.tail = _normalize(head, tail)

  def normalize(self):
    """Return a normalized copy of self."""
    return self._wrap(_normalize(self.head, self.tail))

  @property
  def dtype(self):
    return self.head.dtype

  def to_array(self, dtype=None):
    """Round to a single array ``head + tail`` (optionally cast first)."""
    head, tail = self._tup
    if dtype is not None:
      head = head.astype(dtype)
      tail = tail.astype(dtype)
    return head + jnp.where(jnp.isinf(head), 0, tail)

  def __repr__(self):
    return f"{self.__class__.__name__}({self.head}, {self.tail})"

  @property
  def _tup(self):
    return self.head, self.tail

  def _wrap(self, other):
    # Coerce `other` (pair, scalar, array, ...) to this class and dtype.
    return self.__class__(other, dtype=self.dtype)

  def __abs__(self):
    return self._wrap(_abs2(self._tup))

  def __neg__(self):
    return self._wrap(_neg2(self._tup))

  def __add__(self, other):
    return self._wrap(_add2(self._tup, self._wrap(other)._tup))

  def __sub__(self, other):
    return self._wrap(_sub2(self._tup, self._wrap(other)._tup))

  def __mul__(self, other):
    return self._wrap(_mul2(self._tup, self._wrap(other)._tup))

  def __truediv__(self, other):
    return self._wrap(_div2(self._tup, self._wrap(other)._tup))

  def __radd__(self, other):
    return self._wrap(_add2(self._wrap(other)._tup, self._tup))

  def __rsub__(self, other):
    return self._wrap(_sub2(self._wrap(other)._tup, self._tup))

  def __rmul__(self, other):
    return self._wrap(_mul2(self._wrap(other)._tup, self._tup))

  def __rtruediv__(self, other):
    return self._wrap(_div2(self._wrap(other)._tup, self._tup))

  # Comparisons test the sign of the double-double difference.
  def __lt__(self, other):
    return (self - other).to_array() < 0

  def __le__(self, other):
    return (self - other).to_array() <= 0

  def __gt__(self, other):
    return (self - other).to_array() > 0

  def __ge__(self, other):
    return (self - other).to_array() >= 0

  def __eq__(self, other):
    return (self - other).to_array() == 0

  def __ne__(self, other):
    return (self - other).to_array() != 0