# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict, deque
import itertools as it
import operator as op
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
                    Tuple, Union, NamedTuple)
from warnings import warn

from absl import logging
import numpy as np

from ..config import flags, bool_env, config
from .. import core
from .. import ad_util
from .. import dtypes
from .. import lazy
from .. import linear_util as lu
from jax._src import source_info_util
from ..abstract_arrays import (make_shaped_array, array_types)
from ..core import (ConcreteArray, ShapedArray, AbstractToken,
                    Literal, pp_eqn_compact, raise_to_shaped, abstract_token)
from jax._src.pprint_util import pp
from .._src.util import (partial, partialmethod, cache, prod, unzip2,
                         extend_name_stack, wrap_name, safe_zip, safe_map)
from ..lib import xla_bridge as xb
from ..lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from . import masking

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

xe = xc._xla
xops = xc._xla.ops

# Types
Backend = Any  # xc.LocalBackend (why does mypy not like this?)
Device = Any  # xc.Device
PyLocalBuffer = Any

XlaOp = Any  # xla_extension.XlaOp
XlaShape = Any  # xla_client.Shape
XlaComputationBuilder = Any  # xla_bridge._JaxComputationBuilder
XlaExecutable = Any  # xla_extension.LocalExecutable

FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_debug_nans',
                  bool_env('JAX_DEBUG_NANS', False),
                  'Add nan checks to every operation.')
flags.DEFINE_bool('jax_debug_infs',
                  bool_env('JAX_DEBUG_INFS', False),
                  'Add inf checks to every operation.')
flags.DEFINE_bool('jax_log_compiles',
                  bool_env('JAX_LOG_COMPILES', False),
                  'Print a message each time a `jit` computation is compiled.')

# This flag is set on exit; no logging should be attempted
_on_exit = False

def identity(x): return x

_scalar_types = dtypes.python_scalar_dtypes.keys()

# unit representation
def _make_unit_constant(c): return xb.constant(c, np.zeros((), dtype=np.dtype('bool')))
def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
def _device_put_unit(_, device):
  backend = xb.get_device_backend(device)
  return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype('bool')),
                                    device),)
def _make_array_shape(a):
  if a.dtype is dtypes.float0:
    return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
  else:
    return (xc.Shape.array_shape(a.dtype, a.shape),)

### handlers

xb.register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))

def aval_to_xla_shapes(aval):
  try:
    return xla_shape_handlers[type(aval)](aval)
  except KeyError as err:
    raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err

xla_shape_handlers: Dict[Type[core.AbstractValue], Callable] = {
    core.AbstractUnit: _make_unit_shape,
    ShapedArray: _make_array_shape,
    ConcreteArray: _make_array_shape,
}

def aval_to_result_handler(device: Optional[Device],
                           aval: core.AbstractValue) -> Callable:
  try:
    return xla_result_handlers[type(aval)](device, aval)
  except KeyError as err:
    raise TypeError(f"No xla_result_handler for type: {type(aval)}") from err

def array_result_handler(device: Optional[Device], aval: core.ShapedArray):
  if aval.dtype is dtypes.float0:
    return lambda _: np.zeros(aval.shape, dtypes.float0)
  return partial(make_device_array, raise_to_shaped(aval), device,
                 lazy.array(aval.shape))


xla_result_handlers: Dict[Type[core.AbstractValue], Callable[..., Callable]] = {
    core.AbstractUnit: lambda _, __: lambda _: core.unit,
    ShapedArray: array_result_handler,
    ConcreteArray: array_result_handler,
}

def device_put(x, device: Optional[Device] = None) -> Tuple[Any]:
  x = canonicalize_dtype(x)
  try:
    return device_put_handlers[type(x)](x, device)
  except KeyError as err:
    raise TypeError(f"No device_put handler for type: {type(x)}") from err

def _device_put_array(x, device: Optional[Device]):
  backend = xb.get_device_backend(device)
  if x.dtype is dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  return (backend.buffer_from_pyval(x, device),)

def _device_put_scalar(x, device):
  return _device_put_array(dtypes.coerce_to_array(x), device)

device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]], Tuple[Any]]] = {
    core.Unit: _device_put_unit
}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
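
# For illustration: `device_put` first canonicalizes dtypes (so a Python
# float typically arrives at the dispatch step as an np.ndarray) and then
# dispatches on the exact type of the result via `device_put_handlers`.
# A sketch, not executed here:
#
#   >>> device_put(np.ones((2, 2)))   # _device_put_array; returns a 1-tuple
#   >>> device_put(core.unit)         # _device_put_unit; dummy bool buffer
#
# Each handler returns a tuple of device buffers, one per XLA buffer backing
# the value.
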
# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
  typ = type(x)
  handler = canonicalize_dtype_handlers.get(typ)
  if handler: return handler(x)
  for typ in typ.mro():
    handler = canonicalize_dtype_handlers.get(typ)
    if handler: return handler(x)
  raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")

def _canonicalize_ndarray_dtype(x):
  return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))

def _canonicalize_python_scalar_dtype(typ, x):
  return np.asarray(
      x, dtypes.canonicalize_dtype(dtypes.python_scalar_dtypes[typ]))

canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
canonicalize_dtype_handlers.update(
    (t, _canonicalize_ndarray_dtype) for t in array_types)
canonicalize_dtype_handlers.update(
    (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)

def abstractify(x) -> core.AbstractValue:
  typ = type(x)
  aval_fn = pytype_aval_mappings.get(typ)
  if aval_fn: return aval_fn(x)
  for typ in typ.mro():
    aval_fn = pytype_aval_mappings.get(typ)
    if aval_fn: return aval_fn(x)
  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, _):
  return ShapedArray((), dtypes.python_scalar_dtypes[typ], weak_type=True)

pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {
    core.Unit: lambda _: core.abstract_unit,
}
pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(
    (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
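
# For illustration: `abstractify` (like `canonicalize_dtype` above) first
# tries an exact type lookup and only then walks the type's MRO, so
# subclasses of registered types are handled too. A sketch, not executed
# here (exact reprs depend on the x64 configuration):
#
#   >>> abstractify(np.ones(3, np.float32))
#   ShapedArray(float32[3])
#   >>> abstractify(1.0)   # Python scalars get weakly-typed avals
#   ShapedArray(float64[], weak_type=True)
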
# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap; it can be
# removed once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  if outfeed_rewriter is not None:
    return outfeed_rewriter(jaxpr)
  else:
    return jaxpr

outfeed_primitives: Set[core.Primitive] = set()
def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
  """Finds if there are outfeed primitives anywhere inside a Jaxpr."""
  return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
             for eqn in jaxpr.eqns)

def _param_uses_outfeed(param):
  if type(param) is core.Jaxpr:
    if jaxpr_uses_outfeed(param):
      return True
  elif type(param) is core.ClosedJaxpr:
    if jaxpr_uses_outfeed(param.jaxpr):
      return True
  return False

def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
  if prim in outfeed_primitives:
    return True
  for param in params.values():
    if isinstance(param, tuple):
      if any(unsafe_map(_param_uses_outfeed, param)):
        return True
    elif _param_uses_outfeed(param):
      return True
  return False

### op-by-op execution

def arg_spec(x):
  aval = abstractify(x)
  try:
    return aval, x._device
  except AttributeError:
    return aval, None

def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
                                        **params)
  return compiled_fun(*args)


def _partition_outputs(avals, outs):
  nouts = [aval._num_buffers for aval in avals]
  if not core.skip_checks:
    assert sum(nouts) == len(outs), (
        f"Internal error: sum(nouts)={sum(nouts)} should equal "
        f"len(outs)={len(outs)}.")
  outs = iter(outs)
  return [[next(outs) for _ in range(nout)] for nout in nouts]
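
# For illustration: `_partition_outputs` regroups a flat buffer list
# according to how many buffers each output aval owns (its `_num_buffers`).
# A sketch with hypothetical avals whose `_num_buffers` are 1 and 2, and
# hypothetical buffers (not executed here):
#
#   >>> _partition_outputs([aval_a, aval_b], [buf0, buf1, buf2])
#   [[buf0], [buf1, buf2]]
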
@cache()
def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,
                                                   Optional[Device]], **params):
  avals, arg_devices = unzip2(arg_specs)
  donated_invars = (False,) * len(arg_specs)
  device = _device_from_arg_devices(arg_devices)
  backend = xb.get_device_backend(device)
  if primitive_uses_outfeed(prim, params):
    # We use the _xla_callable path, where we pre-process the primitives
    def prim_fun(*args):
      return prim.bind(*args, **params)
    return _xla_callable(lu.wrap_init(prim_fun), device, None, "prim",
                         donated_invars, *arg_specs)
  aval_out = prim.abstract_eval(*avals, **params)
  if not prim.multiple_results:
    handle_result = aval_to_result_handler(device, aval_out)
  else:
    handlers = map(partial(aval_to_result_handler, device), aval_out)
    handle_result = lambda *bufs: tuple(
        handler(*bs) for handler, bs in
        zip(handlers, _partition_outputs(aval_out, bufs)))
  tuple_args = len(avals) > 100
  if prim in initial_style_translations:
    nreps = initial_style_primitive_replicas(params)
  else:
    nreps = 1

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling a primitive computation `{prim}` that requires {nreps} "
        f"replicas, but only {xb.device_count(backend)} XLA devices are "
        f"available on backend {backend.platform}.")
  built_c = primitive_computation(prim, AxisEnv(nreps, (), ()), backend,
                                  tuple_args, *avals, **params)
  options = xb.get_compile_options(
      num_replicas=nreps,
      num_partitions=1,
      device_assignment=device and (device.id,))
  options.parameter_is_tupled_arguments = tuple_args
  compiled = backend_compile(backend, built_c, options)
  if nreps == 1:
    return partial(_execute_compiled_primitive, prim, compiled, handle_result)
  else:
    return partial(_execute_replicated_primitive, prim, compiled, handle_result)

def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
  """Given devices of inputs, determine where to perform a computation.

  Args:
    devices: list where each element is either a `Device` instance or `None`.
  Returns:
    A `Device` instance or None.
  Raises:
    ValueError if input devices are inconsistent.
  """
  try:
    device, = {d for d in devices if d is not None} or (None,)
    return device
  except ValueError as err:
    msg = "primitive arguments must be colocated on the same device, got {}"
    raise ValueError(msg.format(", ".join(map(str, devices)))) from err

@cache()
def primitive_computation(prim, axis_env, backend, tuple_args, *avals, **params):
  c = xb.make_computation_builder(f"primitive_computation_{prim.name}")
  c.set_op_metadata(xc.OpMetadata(
      op_type=prim.name,
      op_name=str(pp_eqn_compact(prim.name, params))))
  platform = xb.get_backend(backend).platform
  xla_args, _ = _xla_callable_args(c, avals, tuple_args)
  # return val always set as a side-effect on c
  if prim in backend_specific_translations[platform]:
    rule = backend_specific_translations[platform][prim]
    ans = rule(c, *xla_args, **params)
  elif prim in translations:
    rule = translations[prim]
    ans = rule(c, *xla_args, **params)
  elif prim in translations_with_avals:
    rule = translations_with_avals[prim]
    ans = rule(c, avals, xla_args, params)
  elif prim in initial_style_translations:
    rule = initial_style_translations[prim]
    ans = rule(c, axis_env, extend_name_stack(prim.name), avals, backend,
               *xla_args, **params)
  else:
    raise NotImplementedError(f"XLA translation rule for {prim} not found")
  assert isinstance(ans, xe.XlaOp)
  c.clear_op_metadata()
  try:
    return c.build(ans)
  except RuntimeError as e:
    msg = (" ".join(map(str, e.args)) + "\n"
           "This is a bug in JAX's shape-checking rules; please report it!\n"
           "https://github.com/google/jax/issues\n")
    raise RuntimeError(msg) from e

def primitive_subcomputation(prim, *avals, **params):
  axis_env = AxisEnv(1, (), ())
  return primitive_computation(prim, axis_env, None, False, *avals, **params)

def backend_compile(backend, built_c, options):
  # we use a separate function call to ensure that XLA compilation appears
  # separately in Python profiling results
  return backend.compile(built_c, compile_options=options)

def _execute_compiled_primitive(prim, compiled, result_handler, *args):
  device, = compiled.local_devices()
  input_bufs = list(it.chain.from_iterable(
      device_put(x, device) for x in args if x is not token))
  out_bufs = compiled.execute(input_bufs)
  check_special(prim, out_bufs)
  return result_handler(*out_bufs)

def _execute_replicated_primitive(prim, compiled, result_handler, *args):
  input_bufs = [
      list(it.chain.from_iterable(
          device_put(x, device) for x in args if x is not token))
      for device in compiled.local_devices()]
  out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
  return result_handler(*out_bufs)
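
# For illustration: with the jax_debug_nans / jax_debug_infs flags enabled,
# every op-by-op result is scanned by `check_special` below, so an invalid
# value surfaces at the operation that produced it. A sketch of a session,
# not executed here (`jnp` is jax.numpy, assumed imported by the caller):
#
#   >>> from jax.config import config
#   >>> config.update("jax_debug_nans", True)
#   >>> jnp.divide(0., 0.)  # doctest: +SKIP
#   FloatingPointError: invalid value (nan) encountered in div
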
def check_special(prim, bufs):
  if FLAGS.jax_debug_infs or FLAGS.jax_debug_nans:
    for buf in bufs:
      _check_special(prim.name, buf.xla_shape(), buf)

def _check_special(name, xla_shape, buf):
  assert not xla_shape.is_tuple()
  if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
    if FLAGS.jax_debug_nans and np.any(np.isnan(buf.to_py())):
      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if FLAGS.jax_debug_infs and np.any(np.isinf(buf.to_py())):
      raise FloatingPointError(f"invalid value (inf) encountered in {name}")

### compiling jaxprs

def prefetch(x):
  if isinstance(x, DeviceArray):
    x.copy_to_host_async()
  return x

def jaxpr_literals(jaxpr):
  """Generates all the literals inside a jaxpr, including nested subjaxprs."""
  for eqn in jaxpr.eqns:
    for v in eqn.invars:
      if type(v) is core.Literal:
        yield v.val
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(subjaxpr)


def _flatmap(func: Callable, vars: Sequence):
  return list(it.chain.from_iterable(map(func, vars)))

def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
  return map(func, vars, _partition_outputs([v.aval for v in vars], nodes))
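
# `jaxpr_subcomp` below lowers a jaxpr equation by equation, dispatching
# each primitive to the first matching translation table in this fixed
# order: backend_specific_translations[platform], translations,
# translations_with_avals, initial_style_translations, parallel_translations,
# and finally call_translations. For example, a primitive registered only
# in `translations` is lowered as
#
#   >>> ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
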
def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args):
  if backend not in ('cpu', 'gpu', 'tpu'):
    platform = xb.get_backend(backend).platform  # canonicalize
  else:
    platform = backend

  def read(v):
    if type(v) is Literal:
      return [xb.constant(c, canonicalize_dtype(v.val))]
    else:
      return env[v]

  def aval(v):
    if type(v) is Literal:
      return abstractify(v.val)
    else:
      return v.aval

  def write(v, node):
    assert node is not None
    env[v] = node

  env = {}
  _partitionmap(write, [core.unitvar], [_make_unit_constant(c)])
  _partitionmap(write, jaxpr.constvars, consts)
  _partitionmap(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    frame = source_info_util.user_frame(eqn.source_info)
    c.set_op_metadata(xc.OpMetadata(
        op_type=eqn.primitive.name,
        op_name=str(pp(name_stack) >> pp_eqn_compact(
            eqn.primitive.name, eqn.params)),
        source_file=frame.file_name if frame else None,
        source_line=frame.line_num if frame else None))
    in_nodes = _flatmap(read, eqn.invars)
    # TODO(jakevdp): migrate `translations` table to `translations_with_avals`
    if eqn.primitive in backend_specific_translations[platform]:
      rule = backend_specific_translations[platform][eqn.primitive]
      ans = rule(c, *in_nodes, **eqn.params)
    elif eqn.primitive in translations:
      ans = translations[eqn.primitive](c, *in_nodes, **eqn.params)
    elif eqn.primitive in translations_with_avals:
      rule = translations_with_avals[eqn.primitive]
      ans = rule(c, map(aval, eqn.invars), in_nodes, eqn.params)
    elif eqn.primitive in initial_style_translations:
      new_params = check_backend_params(eqn.params, backend)
      rule = initial_style_translations[eqn.primitive]
      ans = rule(c, axis_env, extend_name_stack(name_stack, eqn.primitive.name),
                 map(aval, eqn.invars), backend, *in_nodes, **new_params)
    elif eqn.primitive in parallel_translations:
      rule = parallel_translations[eqn.primitive]
      ans = rule(c, *in_nodes, axis_env=axis_env, platform=platform, **eqn.params)
    elif eqn.primitive in call_translations:
      new_params = check_backend_params(eqn.params, backend)
      rule = call_translations[eqn.primitive]
      ans = rule(c, axis_env, in_nodes,
                 name_stack, backend=backend, **new_params)
    else:
      raise NotImplementedError(
          f"XLA translation rule for primitive '{eqn.primitive.name}' not found")

    assert isinstance(ans, xe.XlaOp)
    c.get_shape(ans)  # force xla to do shape error checking
    if eqn.primitive.multiple_results or any(v.aval._num_buffers > 1
                                             for v in eqn.outvars):
      out_nodes = xla_destructure(c, ans)
    else:
      out_nodes = [ans]
    c.clear_op_metadata()
    _partitionmap(write, eqn.outvars, out_nodes)
  return _flatmap(read, jaxpr.outvars)


def xla_destructure(c, ans):
  num_elements = len(c.get_shape(ans).tuple_shapes())
  return [xops.GetTupleElement(ans, i) for i in range(num_elements)]

def check_backend_params(params, outer_backend):
  # For nested calls, the outermost call sets the backend for all inner calls;
  # it's an error if the inner call has a conflicting explicit backend spec.
  inner_backend = params.get('backend', None)
  if inner_backend and inner_backend != outer_backend:
    raise ValueError(
        f"Outer-jit backend specification {outer_backend} must match explicit "
        f"inner-jit backend specification {inner_backend}.")
  return {k: params[k] for k in params if k != 'backend'}


class AxisEnv(NamedTuple):
  """Represents a pmap mesh (only along the replica axes)."""
  nreps: int
  names: Tuple[Any, ...]
  sizes: Tuple[int, ...]

def extend_axis_env(env: AxisEnv, name, size: int):
  return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))

def axis_read(axis_env, axis_name):
  try:
    return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
  except ValueError:
    raise NameError("unbound axis name: {}".format(axis_name)) from None

def axis_groups(axis_env: AxisEnv, name):
  if not isinstance(name, (list, tuple)):
    name = (name,)
  mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
  trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
  assert not ragged
  mesh_spec = axis_env.sizes + (trailing_size,)
  return _axis_groups(mesh_spec, mesh_axes)

def _axis_groups(mesh_spec, mesh_axes):
  """Computes replica group ids for a collective performed over a subset of the mesh.

  Args:
    mesh_spec: A sequence of integers representing the mesh shape.
    mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
      indicating over which axes the collective is performed.
  Returns:
    A tuple of replica groups (i.e. tuples containing replica ids).
  """
  iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
  groups = np.reshape(
      np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
      (prod(np.take(mesh_spec, mesh_axes)), -1))
  return tuple(unsafe_map(tuple, groups.T))
537 """ 538 return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1) 539 540# TODO(mattjj): this function assumes that only pmap has a parameter named 541# axis_size, and that it corresponds to cross-replica mapping 542def eqn_replicas(eqn): 543 call_jaxpr = eqn.params.get("call_jaxpr") 544 if call_jaxpr: 545 return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr) 546 elif eqn.primitive in initial_style_translations: 547 return initial_style_primitive_replicas(eqn.params) 548 else: 549 return 1 550 551def initial_style_primitive_replicas(params): 552 return max(core.traverse_jaxpr_params(jaxpr_replicas, params), default=1) 553 554# TODO(mattjj,skyewm): the functions here are utilities for checking if 555# not-yet-supported features are used with multi-host programming 556 557def jaxpr_has_pmap(jaxpr): 558 """Whether there is an xla_pmap primitive anywhere inside a Jaxpr.""" 559 for eqn in jaxpr.eqns: 560 if 'xla_pmap' in eqn.primitive.name: 561 return True 562 for subjaxpr in core.subjaxprs(jaxpr): 563 if jaxpr_has_pmap(subjaxpr): 564 return True 565 return False 566 567 568def jaxpr_collectives(jaxpr): 569 """Generates all the collective primitives anywhere inside a Jaxpr.""" 570 for eqn in jaxpr.eqns: 571 if eqn.primitive in parallel_translations: 572 yield eqn.primitive 573 for subjaxpr in core.subjaxprs(jaxpr): 574 yield from jaxpr_collectives(subjaxpr) 575 576 577### xla_call underlying jit 578 579def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars): 580 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars, 581 *unsafe_map(arg_spec, args)) 582 try: 583 return compiled_fun(*args) 584 except FloatingPointError: 585 assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case 586 print("Invalid value encountered in the output of a jit function. " 587 "Calling the de-optimized version.") 588 # We want to run the wrapped function again (after _xla_callable already ran 589 # it), but linear_util.WrappedFun instances are meant to be run only once. 590 # In addition to re-executing the Python code, which is usually undesirable 591 # but which FLAGS.jax_debug_nans is meant to opt into, we'll be re-executing 592 # any linear_util.py-style side effects, i.e. re-populating Stores created 593 # by any transformation_with_aux's applied to fun. Since this is 594 # intentional here, to avoid "Store occupied" errors we reset the stores to 595 # be empty. 596 for store in fun.stores: store and store.reset() 597 return fun.call_wrapped(*args) # probably won't return 598 599def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]: 600 """Expands a given shape tree into a flat list of indices to arrays. 601 602 Given the following computation: 603 604 >>> c = xc.XlaBuilder("example") 605 >>> p0 = xb.parameter(c, 1, xc.shape_from_pyval(jnp.ones([1]))) 606 >>> p1 = xb.parameter(c, 2, xc.shape_from_pyval(jnp.ones([2]))) 607 >>> p2 = xb.parameter(c, 3, xc.shape_from_pyval(jnp.ones([3]))) 608 >>> o = xops.Tuple(c, [p0, p1, p2]) 609 610 We can query the arrays in the output tuple: 611 612 >>> flatten_shape(c.GetShape(o)) 613 (((0,), f32[1]{0}), 614 ((1,), f32[2]{0}), 615 ((2,), f32[3]{0})) 616 617 Or the arrays in one of the parameters (which is itself an array): 618 619 >>> flatten_shape(c.GetShape(p0)) 620 (((), f32[1]{0}),) 621 622 Args 623 s: The input shape. 624 625 Returns: 626 An iterable of pairs of indices and shapes for each array within the shape 627 tree. 
628 """ 629 def _flatten_shape(s, index): 630 if s.is_array(): 631 yield index, s 632 else: 633 assert s.is_tuple() 634 for i, sub in enumerate(s.tuple_shapes()): 635 subindex = index + (i,) 636 if sub.is_tuple(): 637 yield from _flatten_shape(sub, subindex) 638 else: 639 yield subindex, sub 640 return tuple(_flatten_shape(s, index=())) 641 642def _xla_consts(c, consts): 643 unique_consts = {id(const): const for const in consts} 644 xla_consts = { 645 id_: xb.constant(c, const) for id_, const in unique_consts.items()} 646 return [xla_consts[id(const)] for const in consts] 647 648@lu.cache 649def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *arg_specs): 650 if device is not None and backend is not None: 651 raise ValueError("can't specify both a device and a backend for jit, " 652 "got device={} and backend={}".format(device, backend)) 653 654 abstract_args, arg_devices = unzip2(arg_specs) 655 if config.omnistaging_enabled: 656 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args) 657 if any(isinstance(c, core.Tracer) for c in consts): 658 raise core.UnexpectedTracerError("Encountered an unexpected tracer.") 659 else: 660 pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args] 661 jaxpr, pvals, consts = pe.trace_to_jaxpr( # type: ignore 662 fun, pvals, instantiate=False, stage_out=True, bottom=True) # type: ignore 663 map(prefetch, it.chain(consts, jaxpr_literals(jaxpr))) 664 jaxpr = apply_outfeed_rewriter(jaxpr) 665 666 nreps = jaxpr_replicas(jaxpr) 667 device = _xla_callable_device(nreps, backend, device, arg_devices) 668 backend = device.platform if device else backend 669 if config.omnistaging_enabled: 670 result_handlers = map(partial(aval_to_result_handler, device), out_avals) 671 else: 672 out_avals = [pval.get_aval() for pval in pvals] 673 result_handlers = map(partial(_pval_to_result_handler, device), pvals) # type: ignore 674 675 # Computations that only produce constants and/or only rearrange their inputs, 676 # which are often produced from partial evaluation, don't need compilation, 677 # and don't need to force their (potentially lazy) arguments. 678 if not jaxpr.eqns: 679 return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers) 680 681 if not _on_exit: 682 log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG 683 logging.log(log_priority, "Compiling %s for args %s.", fun.__name__, abstract_args) 684 685 if nreps > 1: 686 warn(f"The jitted function {fun.__name__} includes a pmap. Using " 687 "jit-of-pmap can lead to inefficient data movement, as the outer jit " 688 "does not preserve sharded data representations and instead collects " 689 "input and output arrays onto a single device. " 690 "Consider removing the outer jit unless you know what you're doing. 
" 691 "See https://github.com/google/jax/issues/2926.") 692 693 if nreps > xb.device_count(backend): 694 raise ValueError( 695 f"compiling computation that requires {nreps} replicas, but only " 696 f"{xb.device_count(backend)} XLA devices are available") 697 698 if xb.host_count() > 1 and (nreps > 1 or jaxpr_has_pmap(jaxpr)): 699 raise NotImplementedError( 700 "jit of multi-host pmap not implemented (and jit-of-pmap can cause " 701 "extra data movement anyway, so maybe you don't want it after all).") 702 703 tuple_args = len(abstract_args) > 100 # pass long arg lists as tuple for TPU 704 705 c = xb.make_computation_builder("jit_{}".format(fun.__name__)) 706 xla_consts = _xla_consts(c, consts) 707 xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args, donated_invars=donated_invars) 708 out_nodes = jaxpr_subcomp( 709 c, jaxpr, backend, AxisEnv(nreps, (), ()), xla_consts, 710 extend_name_stack(wrap_name(name, 'jit')), *xla_args) 711 out_tuple = xops.Tuple(c, out_nodes) 712 backend = xb.get_backend(backend) 713 if backend.platform in ("gpu", "tpu"): 714 donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args) 715 if any(donated_invars): 716 # TODO(tomhennigan): At call time we should mark these buffers as deleted. 717 unused_donations = [str(c.GetShape(a)) 718 for a, d in zip(xla_args, donated_invars) if d] 719 warn("Some donated buffers were not usable: {}".format(", ".join(unused_donations))) 720 built = c.build(out_tuple) 721 722 options = xb.get_compile_options( 723 num_replicas=nreps, 724 num_partitions=1, 725 device_assignment=(device.id,) if device else None) 726 options.parameter_is_tupled_arguments = tuple_args 727 compiled = backend_compile(backend, built, options) 728 if nreps == 1: 729 return partial(_execute_compiled, compiled, out_avals, result_handlers) 730 else: 731 return partial(_execute_replicated, compiled, out_avals, result_handlers) 732 733def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args): 734 """Configures input/output "must" aliasing based on `donated_args`.""" 735 # First for every input array add it to `donations` iff it is a member of 736 # `donated_args`. 737 donations = defaultdict(deque) 738 for arg_index, arg in enumerate(xla_args): 739 if donated_args[arg_index]: 740 for param_index, element in flatten_shape(c.GetShape(arg)): 741 key = (element.dimensions(), element.numpy_dtype()) 742 if tuple_args: 743 param_number = 0 744 param_index = (arg_index,) + tuple(param_index) 745 donations[key].append((param_number, param_index, arg_index)) 746 else: 747 param_number = arg_index 748 donations[key].append((param_number, param_index, arg_index)) 749 750 # Consume donations for outputs. 
def set_up_aliases(c, xla_args, out_tuple, donated_args, tuple_args):
  """Configures input/output "must" aliasing based on `donated_args`."""
  # First, for every input array add it to `donations` iff it is a member of
  # `donated_args`.
  donations = defaultdict(deque)
  for arg_index, arg in enumerate(xla_args):
    if donated_args[arg_index]:
      for param_index, element in flatten_shape(c.get_shape(arg)):
        key = (element.dimensions(), element.numpy_dtype())
        if tuple_args:
          param_number = 0
          param_index = (arg_index,) + tuple(param_index)
          donations[key].append((param_number, param_index, arg_index))
        else:
          param_number = arg_index
          donations[key].append((param_number, param_index, arg_index))

  # Consume donations for outputs.
  out_donated_args = list(donated_args)
  for output_index, element in flatten_shape(c.get_shape(out_tuple)):
    key = (element.dimensions(), element.numpy_dtype())
    if donations.get(key, ()):
      param_number, param_index, arg_index = donations[key].popleft()
      out_donated_args[arg_index] = False
      c.setup_alias(output_index, param_number, param_index)

  return tuple(out_donated_args)

def _xla_callable_device(nreps, backend, device, arg_devices):
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    if device is None and backend is None:
      return _device_from_arg_devices(arg_devices)
    elif device is not None and backend is None:
      return device
    elif device is None and backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]
    else:
      assert False  # Unreachable given the error check in _xla_callable

# Used within _xla_callable_args and _xla_param to distinguish between None (no
# sharding annotation set) and replicated.
_replicated_param = object()

def _xla_callable_args(
    c, avals, tuple_args, *,
    replicated=None,
    partitions=None,
    partitions_proto: bool = False,
    donated_invars=None):
  assert partitions is None or len(partitions) == len(avals)
  if not tuple_args:
    if replicated is None:
      replicated = [None] * len(avals)
    if partitions is None:
      parts: List[object] = [None] * len(avals)
    elif partitions_proto:
      parts = partitions
    else:
      parts = [_replicated_param if part is None else part
               for part in partitions]
    counts = it.count()
    xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto)
                if a is not abstract_token else xops.CreateToken(c)
                for (a, r, p) in safe_zip(avals, replicated, parts)
                for xla_shape in aval_to_xla_shapes(a)]
    if donated_invars is not None:
      donated_invars = [d
                        for (a, r, p, d) in safe_zip(avals, replicated, parts,
                                                     donated_invars)
                        for xla_shape in aval_to_xla_shapes(a)]
    return xla_args, donated_invars
  else:
    if replicated is not None:
      replicated = [r for a, r in zip(avals, replicated)
                    if a is not abstract_token]
    if partitions is None:
      tuple_parts = None
    elif partitions_proto:
      tuple_parts = xb.tuple_sharding_proto(partitions)
    else:
      tuple_parts = tuple(partitions)
    tuple_shape = xc.Shape.tuple_shape(
        [shape for a in avals for shape in aval_to_xla_shapes(a)
         if a is not abstract_token])
    tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts,
                             partitions_proto)
    xla_inputs = iter(xla_destructure(c, tuple_param))
    xla_args = [next(xla_inputs) if a is not abstract_token else
                xops.CreateToken(c) for a in avals]
    assert next(xla_inputs, None) is None
    return xla_args, donated_invars

def _xla_param(builder, param_num, xla_shape, replicated, partitions,
               parts_proto):
  make_param = partial(xb.parameter, builder, param_num, xla_shape,
                       replicated=replicated)
  with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
  if partitions is None:
    return make_param()
  elif partitions is _replicated_param:
    return with_sharding(builder, None, make_param)
  else:
    return with_sharding(builder, partitions, make_param)

def _execute_compiled(compiled: XlaExecutable, avals, handlers, *args):
  device, = compiled.local_devices()
  input_bufs = list(it.chain.from_iterable(
      device_put(x, device) for x in args if x is not token))
  out_bufs = compiled.execute(input_bufs)
  check_special(xla_call_p, out_bufs)
  return [handler(*bs) for handler, bs in
          zip(handlers, _partition_outputs(avals, out_bufs))]

def _execute_replicated(compiled: XlaExecutable, avals, handlers, *args):
  input_bufs = [
      list(it.chain.from_iterable(
          device_put(x, device) for x in args if x is not token))
      for device in compiled.local_devices()]
  out_bufs = compiled.execute_on_local_devices(input_bufs)[0]
  check_special(xla_call_p, out_bufs)
  return [handler(*bs) for handler, bs in
          zip(handlers, _partition_outputs(avals, out_bufs))]

def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     *args):
  env = {core.unitvar: core.unit}
  map(env.setdefault, jaxpr.invars, args)
  map(env.setdefault, jaxpr.constvars, consts)
  outs = [canonicalize_dtype(v.val) if type(v) is Literal else env[v]
          for v in jaxpr.outvars]
  return [_copy_device_array_to_device(x, device) if type_is_device_array(x)
          else h(*device_put(x, device)) for h, x in zip(handlers, outs)]

xla_call_p = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
xla_call_p.def_impl(_xla_call_impl)

def _xla_call_partial_eval_update_params(params, in_unknowns):
  call_jaxpr = params['call_jaxpr']
  donated_invars = params['donated_invars']
  if not in_unknowns and donated_invars:
    # JaxprTrace.post_process_call creates a call with no input tracers
    new_donated_invars = (False,) * len(call_jaxpr.invars)
  else:
    # JaxprTrace.process_call drops known input tracers
    donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk]
    new_donated_invars = ((False,) * (len(call_jaxpr.invars) - len(donated_invars))
                          + tuple(donated_invars))
  return dict(params, donated_invars=new_donated_invars)
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params

def _xla_call_jvp_update_params(params, nz_tangents):
  donated_invars = params['donated_invars']
  donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
  new_donated_invars = (*donated_invars, *donated_tangents)
  return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params

def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
  donated_invars = params['donated_invars']
  donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
  donated_cotangents = [False for nz in nonzero_cts if nz]
  return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params


def _xla_call_translation_rule(c, axis_env,
                               in_nodes, name_stack, backend, name,
                               call_jaxpr, donated_invars, device=None):
  del device, donated_invars  # Ignored.
  subc = xb.make_computation_builder(f"jit_{name}")
  args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
  out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
                            extend_name_stack(name_stack, wrap_name(name, 'jit')),
                            *args)
  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xops.Call(c, subc, list(in_nodes))
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)


### translation tables

translations: Dict[core.Primitive, Callable] = {}
translations_with_avals: Dict[core.Primitive, Callable] = {}
parallel_translations: Dict[core.Primitive, Callable] = {}
initial_style_translations: Dict[core.Primitive, Callable] = {}
call_translations: Dict[core.Primitive, Callable] = {}
backend_specific_translations: Dict[str, Dict[core.Primitive, Callable]] = defaultdict(dict)

call_translations[xla_call_p] = _xla_call_translation_rule

def zeros_like_translation_rule(c, x):
  shape = c.get_shape(x)
  assert not shape.is_tuple()
  zero = xb.constant(c, np.array(0, shape.element_type()))
  return xops.Broadcast(zero, shape.dimensions())
translations[ad_util.zeros_like_p] = zeros_like_translation_rule

def add_jaxvals_translation_rule(c, x, y):
  shape = c.get_shape(x)
  assert not shape.is_tuple()
  return xops.Add(x, y)
translations[ad_util.add_jaxvals_p] = add_jaxvals_translation_rule

translations[ad_util.stop_gradient_p] = lambda c, x: x


@lu.transformation
def _tuple_output(*args, **kwargs):
  ans = yield args, kwargs
  yield (ans,)

def lower_fun(fun, multiple_results, parallel=False, with_avals=False):
  # TODO(jakevdp): migrate dependent code & always use with_avals=True
  def f(c, *xla_args, **params):
    avals = [_array_aval_from_xla_shape(c.get_shape(x)) for x in xla_args]
    return f_with_avals(c, avals, xla_args, params)

  def f_with_avals(c, avals, xla_args, params):
    if parallel:
      axis_env = params.pop('axis_env')
      del params['platform']
    else:
      axis_env = AxisEnv(1, (), ())
    wrapped_fun = lu.wrap_init(fun, params)
    if not multiple_results:
      wrapped_fun = _tuple_output(wrapped_fun)
    if config.omnistaging_enabled:
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
      outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
                           *xla_args)
    else:
      pvals = [pe.PartialVal.unknown(a) for a in avals]
      jaxpr, _, consts = pe.trace_to_jaxpr(wrapped_fun, pvals, instantiate=True,
                                           stage_out=True)  # type: ignore
      xla_consts = _xla_consts(c, consts)
      outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args)
    if multiple_results or any(v.aval._num_buffers > 1 for v in jaxpr.outvars):
      return xops.Tuple(c, outs)
    else:
      assert len(outs) == 1, outs
      return outs[0]

  return f_with_avals if with_avals else f
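
# For illustration: `lower_fun` turns a traceable Python function into a
# translation rule by tracing it to a jaxpr and lowering it with
# jaxpr_subcomp. A hypothetical registration (the names `my_p` and
# `_my_impl` are illustrative, not part of this module):
#
#   translations[my_p] = lower_fun(_my_impl, multiple_results=False)
#
# where `_my_impl` is an ordinary function implementing the primitive in
# terms of existing JAX primitives.
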
def _array_aval_from_xla_shape(xla_shape):
  # This function instantiates the assumption that we can map from XLA array
  # types to JAX array types.
  # TODO(mattjj): remove the assumption that we can map XLA array types to
  # JAX array types
  assert not xla_shape.is_tuple()
  return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())

def lower_fun_initial_style(fun):
  def f(c, axis_env, name_stack, avals, backend, *xla_args, **params):
    if config.omnistaging_enabled:
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(lu.wrap_init(fun, params),
                                                   avals)
      outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
                           name_stack, *xla_args)
    else:
      pvals = [pe.PartialVal.unknown(a) for a in avals]
      jaxpr, _, consts = pe.trace_to_jaxpr(
          lu.wrap_init(fun, params), pvals, instantiate=True, stage_out=True)  # type: ignore
      xla_consts = _xla_consts(c, consts)
      outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack,
                           *xla_args)
    return xops.Tuple(c, outs)
  return f


### device-persistent data

class Token(object): pass
token = Token()

pytype_aval_mappings[Token] = lambda _: abstract_token
core.pytype_aval_mappings[Token] = lambda _: abstract_token
xla_shape_handlers[AbstractToken] = lambda _: (xc.Shape.token_shape(),)
xla_result_handlers[AbstractToken] = lambda _, __: lambda _: token
canonicalize_dtype_handlers[Token] = identity


def _forward_method(attrname, self, fun, *args):
  return fun(getattr(self, attrname), *args)
_forward_to_value = partial(_forward_method, "_value")


# The following is used for the type _CppDeviceArray or _DeviceArray.
DeviceArrayProtocol = Any
if hasattr(xc, "DeviceArrayBase"):
  DeviceArray = xc.DeviceArrayBase
else:
  # prior to jaxlib version 0.1.58.
  class DeviceArray:  # type: ignore
    pass

_CppDeviceArray: DeviceArrayProtocol = xc.Buffer

_EXPERIMENTAL_CPP_DEVICE_ARRAY = False


def make_device_array(
    aval: core.ShapedArray,
    device: Optional[Device],
    lazy_expr: Optional[lazy.LazyExpr],
    device_buffer: Union[PyLocalBuffer, "DeviceConstant"],
) -> Union[PyLocalBuffer, "_DeviceArray"]:
  """Returns a DeviceArray implementation based on arguments.

  This is to be used only within JAX. It will return either a Python
  `_DeviceArray` or a C++ equivalent implementation.
  """
  if (_EXPERIMENTAL_CPP_DEVICE_ARRAY and lazy.is_trivial(lazy_expr) and
      not isinstance(device_buffer, DeviceConstant)):
    assert isinstance(device_buffer, _CppDeviceArray)
    device_buffer._device = device  # pylint: disable=protected-access
    device_buffer.aval = aval
    return device_buffer

  return _DeviceArray(aval, device, lazy_expr, device_buffer)


def type_is_device_array(x):
  """Returns `True` if `x` is a non-sharded DeviceArray.

  Use this function instead of `type(x) is DeviceArray`.
  """
  type_x = type(x)
  return type_x is _DeviceArray or type_x is _CppDeviceArray
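
# For illustration (an assumption about the class hierarchy, not asserted by
# this module): sharded array types subclass `DeviceArray`, so an
# `isinstance(x, DeviceArray)` check is broader than `type_is_device_array`,
# whose exact-type test matches only the two single-buffer implementations:
#
#   >>> x = jnp.ones(3)           # a _DeviceArray or a _CppDeviceArray
#   >>> type_is_device_array(x)
#   True
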
class _DeviceArray(DeviceArray):  # type: ignore
  """A DeviceArray is an ndarray backed by a single device memory buffer."""
  # We don't subclass ndarray because that would open up a host of issues,
  # but lax_numpy.py overrides isinstance behavior and attaches ndarray methods.
  __slots__ = [
      "aval", "device_buffer", "_npy_value", "_device", "_lazy_expr"
  ]
  __array_priority__ = 100

  # DeviceArray has methods that are dynamically populated in lax_numpy.py,
  # and this annotation is needed to make pytype happy.
  _HAS_DYNAMIC_ATTRIBUTES = True

  def __init__(self, aval: core.ShapedArray, device: Optional[Device],
               lazy_expr: Optional[lazy.LazyExpr],
               device_buffer: PyLocalBuffer):
    """Initializer.

    Args:
      aval: The abstract value associated with this array (shape+dtype+weak_type).
      device: The optional sticky device. See
        https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
      lazy_expr: An optional `LazyExpr`. `None` is equivalent to a trivial
        `LazyExpr`.
      device_buffer: The underlying buffer owning the on-device data.
    """
    DeviceArray.__init__(self)
    self.aval = aval
    self.device_buffer = device_buffer
    self._device = device
    self._lazy_expr = lazy_expr

    self._npy_value = None
    if not core.skip_checks:
      assert type(aval) is ShapedArray
      npy_value = self._value
      assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape
      assert (device is None) or device is device_buffer.device()

  def _check_if_deleted(self):
    if self.device_buffer is deleted_buffer:
      raise RuntimeError("DeviceArray has been deleted.")

  def block_until_ready(self):
    """Blocks the caller until the buffer's value has been computed on device.

    This method is mostly useful for timing microbenchmarks that wish to
    time how long a computation takes, without transferring the result back
    to the host.

    Returns the buffer object (`self`).
    """
    self._check_if_deleted()
    self.device_buffer.block_host_until_ready()  # pytype: disable=attribute-error
    return self

  @property
  def _value(self):
    self._check_if_deleted()
    if self._npy_value is None:
      if is_device_constant(self):
        self._npy_value = lazy.eval_lexpr(self._lazy_expr, None)
      else:
        self._npy_value = _force(self).device_buffer.to_py()
      self._npy_value.flags.writeable = False
    return self._npy_value

  @property
  def shape(self):
    return self.aval.shape

  @property
  def dtype(self):
    return self.aval.dtype

  @property
  def size(self):
    return prod(self.aval.shape)

  @property
  def ndim(self):
    return len(self.aval.shape)

  def copy_to_host_async(self):
    """Requests a copy of the buffer to the host."""
    self._check_if_deleted()
    if self._npy_value is None and not is_device_constant(self):
      self.device_buffer.copy_to_host_async()  # pytype: disable=attribute-error

  def delete(self):
    """Deletes the device array and any cached copy on the host.

    It is an error to access the contents of a `DeviceArray` after it has
    been deleted.

    Use of this method is optional; device buffers will be reclaimed
    automatically by Python when a DeviceArray object is garbage collected.
    However, it is sometimes useful to have more explicit control over the
    time of deletion.
1157 """ 1158 self.device_buffer.delete() # pytype: disable=attribute-error 1159 self.device_buffer = deleted_buffer 1160 self._npy_value = None 1161 1162 @property 1163 def __cuda_array_interface__(self): 1164 return _force(self).device_buffer.__cuda_array_interface__ 1165 1166 1167# Adding methods dynamically to both _DeviceArray and _CppDeviceArray 1168# pylint: disable=protected-access 1169for device_array in [_DeviceArray, _CppDeviceArray]: 1170 1171 1172 def copy(self): 1173 """Returns an ndarray (backed by host memory, not device memory).""" 1174 return np.asarray(self) 1175 setattr(device_array, "copy", copy) 1176 1177 def __repr__(self): 1178 line_width = np.get_printoptions()["linewidth"] 1179 prefix = '{}('.format(self.__class__.__name__.lstrip('_')) 1180 s = np.array2string(self._value, prefix=prefix, suffix=',', 1181 separator=', ', max_line_width=line_width) 1182 dtype_str = 'dtype={})'.format(self.dtype.name) 1183 last_line_len = len(s) - s.rfind('\n') + 1 1184 sep = ' ' 1185 if last_line_len + len(dtype_str) + 1 > line_width: 1186 sep = ' ' * len(prefix) 1187 return "{}{},{}{}".format(prefix, s, sep, dtype_str) 1188 1189 setattr(device_array, "__repr__", __repr__) 1190 1191 def item(self): 1192 if dtypes.issubdtype(self.dtype, np.complexfloating): 1193 return complex(self) 1194 elif dtypes.issubdtype(self.dtype, np.floating): 1195 return float(self) 1196 elif dtypes.issubdtype(self.dtype, np.integer): 1197 return int(self) 1198 elif dtypes.issubdtype(self.dtype, np.bool_): 1199 return bool(self) 1200 else: 1201 raise TypeError(self.dtype) 1202 1203 setattr(device_array, "item", item) 1204 1205 def __len__(self): 1206 try: 1207 return self.aval.shape[0] 1208 except IndexError as err: 1209 raise TypeError("len() of unsized object") from err # same as numpy error 1210 1211 setattr(device_array, "__len__", __len__) 1212 1213 def __iter__(self): 1214 if self.ndim == 0: 1215 raise TypeError("iteration over a 0-d array") # same as numpy error 1216 else: 1217 return self._value.__iter__() 1218 1219 setattr(device_array, "__iter__", __iter__) 1220 1221 def __reversed__(self): 1222 if self.ndim == 0: 1223 raise TypeError("iteration over a 0-d array") 1224 else: 1225 return reversed(self._value) 1226 1227 setattr(device_array, "__reversed__", __reversed__) 1228 1229 def __format__(self, format_spec): 1230 # Simulates behavior of https://github.com/numpy/numpy/pull/9883 1231 if self.ndim == 0: 1232 return format(self._value[()], format_spec) 1233 else: 1234 return format(self._value, format_spec) 1235 1236 setattr(device_array, "__format__", __format__) 1237 1238 def __array__(self, dtype=None, context=None): 1239 return np.asarray(self._value, dtype=dtype) 1240 1241 setattr(device_array, "__array__", __array__) 1242 1243 setattr(device_array, "__str__", partialmethod(_forward_to_value, str)) 1244 setattr(device_array, "__bool__", partialmethod(_forward_to_value, bool)) 1245 setattr(device_array, "__nonzero__", partialmethod(_forward_to_value, bool)) 1246 setattr(device_array, "__float__", lambda self: self._value.__float__()) 1247 setattr(device_array, "__int__", lambda self: self._value.__int__()) 1248 setattr(device_array, "__complex__", lambda self: self._value.__complex__()) 1249 setattr(device_array, "__hex__", partialmethod(_forward_to_value, hex)) 1250 setattr(device_array, "__oct__", partialmethod(_forward_to_value, oct)) 1251 setattr(device_array, "__index__", partialmethod(_forward_to_value, op.index)) 1252 to_bytes = lambda self, order="C": self._value.tobytes(order) 1253 
setattr(device_array, "tobytes", to_bytes) 1254 del to_bytes 1255 setattr(device_array, "tolist", lambda self: self._value.tolist()) 1256 1257 # pickle saves and loads just like an ndarray 1258 setattr(device_array, "__reduce__", 1259 partialmethod(_forward_to_value, op.methodcaller("__reduce__"))) 1260 1261 # clobbered when jax.numpy is imported, but useful in tests 1262 setattr(device_array, "__eq__", lambda self, other: self._value == other) 1263 1264 def __hash__(self): 1265 raise TypeError("JAX DeviceArray, like numpy.ndarray, is not hashable.") 1266 1267 setattr(device_array, "__hash__", __hash__) 1268 1269 # The following methods are dynamically overridden in lax_numpy.py. 1270 def raise_not_implemented(): 1271 raise NotImplementedError 1272 1273 setattr(device_array, "__getitem__", lambda self, i: raise_not_implemented()) 1274# pylint: enable=protected-access 1275 1276 1277class DeletedBuffer(object): pass 1278deleted_buffer = DeletedBuffer() 1279 1280class DeviceConstant(object): 1281 __slots__ = ["_device"] 1282 def __init__(self, device=None): self._device = device 1283 def device(self): return self._device 1284 def to_py(self): return None 1285 1286def is_device_constant(x): 1287 return type_is_device_array(x) and type(x.device_buffer) is DeviceConstant 1288 1289for device_array in [_CppDeviceArray, _DeviceArray]: 1290 core.literalable_types.add(device_array) 1291 core.pytype_aval_mappings[device_array] = ConcreteArray 1292 pytype_aval_mappings[device_array] = op.attrgetter('aval') 1293 canonicalize_dtype_handlers[device_array] = identity 1294 1295def _device_array_constant_handler(c, val, canonicalize_types=True): 1296 if is_device_constant(val): 1297 return lazy.stage_lexpr(c, val._lazy_expr, None) 1298 else: 1299 base_val = xb.constant(c, val.device_buffer.to_py()) 1300 return lazy.stage_lexpr(c, val._lazy_expr, base_val) 1301xb.register_constant_handler(_DeviceArray, _device_array_constant_handler) 1302xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler) 1303 1304def _device_put_device_array(x: Union[DeviceArrayProtocol, _DeviceArray], device: Optional[Device]): 1305 x = _copy_device_array_to_device(x, device) 1306 return (_force(x).device_buffer,) 1307device_put_handlers[_CppDeviceArray] = _device_put_device_array 1308device_put_handlers[_DeviceArray] = _device_put_device_array 1309 1310def _copy_device_array_to_device(x: Union[DeviceArrayProtocol, _DeviceArray], device: Optional[xc.Device]) -> Union[DeviceArrayProtocol, _DeviceArray]: 1311 if device is None: 1312 # no copying to be done because there's no target specified 1313 return x 1314 elif is_device_constant(x): 1315 # create a new DeviceArray with the same lazy expr, no copying 1316 return make_device_array(x.aval, device, x._lazy_expr, 1317 DeviceConstant(device)) 1318 elif xb.get_device_backend(device).platform == x.device_buffer.platform(): 1319 # source and target platforms are the same 1320 if x.device_buffer.device() == device: 1321 # no copying to be done because source equals target 1322 if x._device == device: 1323 return x 1324 else: 1325 moved_buf = x.device_buffer # We need to change stickyness 1326 else: 1327 # move the buffer with a device-to-device copy 1328 moved_buf = x.device_buffer.copy_to_device(device) 1329 else: 1330 # buffers from different XLA backends are passed through the host. 
def _copy_device_array_to_device(
    x: Union[DeviceArrayProtocol, _DeviceArray],
    device: Optional[xc.Device]) -> Union[DeviceArrayProtocol, _DeviceArray]:
  if device is None:
    # no copying to be done because there's no target specified
    return x
  elif is_device_constant(x):
    # create a new DeviceArray with the same lazy expr, no copying
    return make_device_array(x.aval, device, x._lazy_expr,
                             DeviceConstant(device))
  elif xb.get_device_backend(device).platform == x.device_buffer.platform():
    # source and target platforms are the same
    if x.device_buffer.device() == device:
      # no copying to be done because source equals target
      if x._device == device:
        return x
      else:
        moved_buf = x.device_buffer  # We need to change stickiness
    else:
      # move the buffer with a device-to-device copy
      moved_buf = x.device_buffer.copy_to_device(device)
  else:
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
  return _DeviceArray(x.aval, device, x._lazy_expr, moved_buf)

def _force(x: DeviceArrayProtocol) -> DeviceArrayProtocol:
  if lazy.is_trivial(x._lazy_expr):
    return x
  else:
    # force x on the device where it lives, but preserve stickiness on result
    if x._device:
      device = x._device
    else:
      device = x.device_buffer.device()
    force_fun = _lazy_force_computation(x.aval, device, x._lazy_expr)
    result = force_fun(x)
    return make_device_array(x.aval, x._device, lazy.array(x.aval.shape), result)

@cache()
def _lazy_force_computation(aval: core.ShapedArray,
                            device: Device, lexpr: lazy.LazyExpr
                            ) -> Callable[[_DeviceArray], PyLocalBuffer]:
  c = xb.make_computation_builder("lazy_force")
  if lazy.is_constant(lexpr):
    param = None
  else:
    idxs = [(src, dst) for dst, src in enumerate(lexpr.dims) if src is not None]
    param_shape = [None] * len(idxs)
    for src, dst in idxs:
      param_shape[src] = aval.shape[dst]
    param = xb.parameter(c, 0, xc.Shape.array_shape(aval.dtype, param_shape))
  xla_out = lazy.stage_lexpr(c, lexpr, param)
  built_c = c.build(xla_out)

  device = _device_from_arg_devices([device])
  options = xb.get_compile_options(
      num_replicas=1,
      num_partitions=1,
      device_assignment=device and (device.id,))
  compiled = backend_compile(xb.get_device_backend(device), built_c, options)

  force_fun: Callable[[_DeviceArray], PyLocalBuffer]
  if lazy.is_constant(lexpr):
    def force_fun(_):
      return compiled.execute([])[0]
  else:
    def force_fun(x):
      return compiled.execute([x.device_buffer])[0]
  return force_fun


def _device_put_impl(x, device: Optional[Device] = None):
  if type_is_device_array(x):
    return _copy_device_array_to_device(x, device)

  try:
    a = abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
  return aval_to_result_handler(device, a)(*device_put(x, device))

device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
translations[device_put_p] = lambda c, x, device=None: x
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
masking.defvectorized(device_put_p)


def _remat_translation_rule(c, axis_env, in_nodes,
                            name_stack, backend, name, call_jaxpr,
                            device=None, concrete=None):
  """Lower remat to a Conditional whose predicate is always true. This:
  1. Circumvents common subexpression elimination.
  2. In the common case of `jax.grad(jax.remat(f))`, ensures the remat blocks
     occur after the primal blocks, because the cotangent is an input to the
     Conditional."""
  del device, concrete  # Unused.
  # Fake condition which always selects the true branch.
  rng = xops.RngUniform(xb.constant(c, np.array(0, dtype=np.float32)),
                        xb.constant(c, np.array(1, dtype=np.float32)),
                        xc.Shape.array_shape(xc.PrimitiveType.F32, []))
  pred = xops.Lt(rng, xb.constant(c, np.array(2, dtype=np.float32)))

  true_op = xops.Tuple(c, in_nodes)
  remat_subc = xb.make_computation_builder("remat_call_subcomputation")
  input_op = xb.parameter(remat_subc, 0, c.get_shape(true_op), replicated=[])
  args = [xops.GetTupleElement(input_op, i) for i in range(len(in_nodes))]
  out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
                            extend_name_stack(name_stack, wrap_name(name, 'remat')),
                            *args)
  out_node_shapes = [remat_subc.get_shape(o) for o in out_nodes]
  remat_subc = remat_subc.build(xops.Tuple(remat_subc, out_nodes))

  false_op = true_op
  dummy_subc = xb.make_computation_builder("remat_call_dummy_subcomputation")
  xb.parameter(dummy_subc, 0, c.get_shape(false_op), replicated=[])

  def zeros(xla_shape):
    if xla_shape.is_array():
      shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
      zero = xb.constant(dummy_subc, np.array(0, dtype=dtype))
      return xops.Broadcast(zero, shape)
    else:
      # It is a token
      return xops.CreateToken(dummy_subc)
  out_nodes = [zeros(s) for s in out_node_shapes]
  dummy_subc = dummy_subc.build(xops.Tuple(dummy_subc, out_nodes))

  return xops.Conditional(pred, true_op, remat_subc, false_op, dummy_subc)
call_translations[pe.remat_call_p] = _remat_translation_rule  # type: ignore


ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
                                                     core.named_call_p)


def _named_call_translation_rule(c, axis_env, in_nodes, name_stack, *,
                                 name="core_call", backend, call_jaxpr):
  subc = xb.make_computation_builder(name)
  args = [xb.parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
  out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
                            extend_name_stack(name_stack, name), *args)
  subc = subc.build(xops.Tuple(subc, out_nodes))
  return xops.Call(c, subc, list(in_nodes))
call_translations[core.named_call_p] = _named_call_translation_rule


def _call_translation_rule(c, axis_env, in_nodes, name_stack, *, backend,
                           call_jaxpr):
  return _named_call_translation_rule(
      c, axis_env, in_nodes, name_stack, name="core_call",
      backend=backend, call_jaxpr=call_jaxpr)
call_translations[core.call_p] = _call_translation_rule


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global _pval_to_result_handler

  def _pval_to_result_handler(device, pval):
    pv, const = pval
    if pv is None:
      const = _device_put_impl(const, device) if device else const
      return lambda _: const
    else:
      return aval_to_result_handler(device, pv)

  pe.staged_out_calls.add(xla_call_p)  # type: ignore