1# Copyright 2018 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15"""Interface and utility functions to XLA. 16 17This module wraps the XLA client(s) and builders to standardize their interfaces 18and provide some automatic type mapping logic for converting between Numpy and 19XLA. There are also a handful of related casting utilities. 20""" 21 22 23from functools import partial, lru_cache 24import os 25from typing import Callable, Dict, List, Optional, Tuple, Union 26 27from absl import logging 28# Disable "WARNING: Logging before flag parsing goes to stderr." message 29logging._warn_preinit_stderr = 0 30 31from ..config import flags 32from jax._src import util 33from .. import dtypes 34import numpy as np 35import threading 36 37try: 38 from . import tpu_client 39except ImportError: 40 tpu_client = None 41from . import xla_client 42 43xops = xla_client.ops 44 45FLAGS = flags.FLAGS 46 47flags.DEFINE_string( 48 'jax_xla_backend', 'xla', 49 'Default is "xla" for the XLA service directly, ' 50 'or "tpu_driver" for using high-performance access to Cloud TPU hardware.') 51flags.DEFINE_string( 52 'jax_backend_target', 'local', 53 'Either "local" or "rpc:address" to connect to a remote service target.') 54flags.DEFINE_string( 55 'jax_platform_name', 56 os.getenv('JAX_PLATFORM_NAME', ''), 57 'Platform name for XLA. The default is to attempt to use a GPU if ' 58 'available, but fall back to CPU otherwise. To set the platform manually, ' 59 'pass "cpu" for CPU or "gpu" for GPU.') 60flags.DEFINE_bool( 61 'jax_disable_most_optimizations', False, 62 'Try not to do much optimization work. This can be useful if the cost of ' 63 'optimization is greater than that of running a less-optimized program.') 64 65 66def get_compile_options( 67 num_replicas: int, 68 num_partitions: int, 69 device_assignment=None, 70 use_spmd_partitioning: bool = True, 71) -> xla_client.CompileOptions: 72 """Returns the compile options to use, as derived from flag values. 73 74 Args: 75 num_replicas: Number of replicas for which to compile. 76 num_partitions: Number of partitions for which to compile. 77 device_assignment: Optional tuple of integers indicating the assignment of 78 logical replicas to physical devices (default inherited from 79 xla_client.CompileOptions). Must be consistent with `num_replicas` and 80 `num_partitions`. 81 use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD 82 partitioning in XLA. 83 """ 84 compile_options = xla_client.CompileOptions() 85 compile_options.num_replicas = num_replicas 86 compile_options.num_partitions = num_partitions 87 build_options = compile_options.executable_build_options 88 build_options.use_spmd_partitioning = use_spmd_partitioning 89 if device_assignment is not None: 90 logging.vlog( 91 2, 92 'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s', 93 num_replicas, num_partitions, device_assignment) 94 device_assignment = np.array(device_assignment) 95 96 # Allow 1D device assignment if num_partitions is 1. 97 if (device_assignment.ndim == 1) and (num_partitions == 1): 98 device_assignment = device_assignment[:, None] 99 100 if num_replicas != device_assignment.shape[0]: 101 msg = 'device_assignment does not match num_replicas: {} vs {}.' 102 raise ValueError(msg.format(device_assignment, num_replicas)) 103 104 if num_partitions != device_assignment.shape[1]: 105 msg = 'device_assignment does not match num_partitions: {} vs {}.' 106 raise ValueError(msg.format(device_assignment, num_partitions)) 107 108 device_assignment = xla_client.DeviceAssignment.create(device_assignment) 109 assert device_assignment.replica_count() == num_replicas 110 assert device_assignment.computation_count() == num_partitions 111 compile_options.device_assignment = device_assignment 112 113 if FLAGS.jax_disable_most_optimizations: 114 debug_options = compile_options.executable_build_options.debug_options 115 debug_options.xla_backend_optimization_level = 0 116 debug_options.xla_llvm_disable_expensive_passes = True 117 debug_options.xla_test_all_input_layouts = False 118 119 return compile_options 120 121_backends = {} 122 123def register_backend(name, factory): 124 _backends[name] = factory 125 126def _get_local_backend(platform=None): 127 if not platform: 128 platform = FLAGS.jax_platform_name or None 129 130 backend = xla_client.get_local_backend(platform) 131 if backend is None: 132 raise RuntimeError("No local XLA backends found.") 133 134 if backend.platform == 'cpu' and platform != 'cpu': 135 logging.warning('No GPU/TPU found, falling back to CPU. ' 136 '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)') 137 138 return backend 139 140 141register_backend('xla', _get_local_backend) 142 143# memoize the TPU driver to be consistent with xla_client behavior 144_tpu_backend = None 145 146def _get_tpu_driver_backend(platform): 147 del platform 148 global _tpu_backend 149 if _tpu_backend is None: 150 backend_target = FLAGS.jax_backend_target 151 if backend_target is None: 152 raise ValueError('When using TPU Driver as the backend, you must specify ' 153 '--jax_backend_target=<hostname>:8470.') 154 _tpu_backend = tpu_client.TpuBackend.create(worker=backend_target) 155 return _tpu_backend 156 157 158if tpu_client: 159 register_backend('tpu_driver', _get_tpu_driver_backend) 160 161 162_backend_lock = threading.Lock() 163 164@lru_cache(maxsize=None) # don't use util.memoize because there is no X64 dependence. 165def get_backend(platform=None): 166 # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how 167 # 'backend' values are handled 168 if not isinstance(platform, (type(None), str)): 169 return platform 170 171 with _backend_lock: 172 backend = _backends.get(FLAGS.jax_xla_backend) 173 if backend is None: 174 msg = 'Unknown jax_xla_backend value "{}".' 175 raise ValueError(msg.format(FLAGS.jax_xla_backend)) 176 return backend(platform) 177 178 179def get_device_backend(device=None): 180 """Returns the Backend associated with `device`, or the default Backend.""" 181 platform = device.platform if device else None 182 return get_backend(platform) 183 184 185def device_count(backend: Optional[str] = None) -> int: 186 """Returns the total number of devices. 187 188 On most platforms, this is the same as :py:func:`jax.local_device_count`. 189 However, on multi-host platforms, this will return the total number of devices 190 across all hosts. 191 192 Args: 193 backend: This is an experimental feature and the API is likely to change. 194 Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or 195 ``'tpu'``. 196 197 Returns: 198 Number of devices. 199 """ 200 return int(get_backend(backend).device_count()) 201 202 203def local_device_count(backend: Optional[str] = None) -> int: 204 """Returns the number of devices on this host.""" 205 return int(get_backend(backend).local_device_count()) 206 207 208def devices(backend: Optional[str] = None) -> List[xla_client.Device]: 209 """Returns a list of all devices for a given backend. 210 211 Each device is represented by a subclass of :class:`Device` (e.g. 212 :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is 213 equal to ``device_count(backend)``. Local devices can be identified by comparing 214 :meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`. 215 216 If ``backend`` is ``None``, returns all the devices from the default backend. 217 The default backend is generally ``'gpu'`` or ``'tpu'`` if available, 218 otherwise ``'cpu'``. 219 220 Args: 221 backend: This is an experimental feature and the API is likely to change. 222 Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or 223 ``'tpu'``. 224 225 Returns: 226 List of Device subclasses. 227 """ 228 return get_backend(backend).devices() 229 230 231def local_devices(host_id: Optional[int] = None, 232 backend: Optional[str] = None) -> List[xla_client.Device]: 233 """Like :py:func:`jax.devices`, but only returns devices local to a given host. 234 235 If ``host_id`` is ``None``, returns devices local to this host. 236 237 Args: 238 host_id: the integer ID of the host. Host IDs can be retrieved via 239 :py:func:`jax.host_ids`. 240 backend: This is an experimental feature and the API is likely to change. 241 Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or 242 ``'tpu'``. 243 244 Returns: 245 List of Device subclasses. 246 """ 247 if host_id is None: 248 host_id = get_backend(backend).host_id() 249 if host_id not in host_ids(): 250 raise ValueError(f"Unknown host_id {host_id}") 251 return [d for d in devices(backend) if d.host_id == host_id] 252 253 254def host_id(backend: Optional[str] = None) -> int: 255 """Returns the integer host ID of this host. 256 257 On most platforms, this will always be 0. This will vary on multi-host 258 platforms though. 259 260 Args: 261 backend: This is an experimental feature and the API is likely to change. 262 Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or 263 ``'tpu'``. 264 265 Returns: 266 Integer host ID. 267 """ 268 return get_backend(backend).host_id() 269 270 271def host_ids(backend: Optional[str] = None) -> List[int]: 272 """Returns a sorted list of all host IDs.""" 273 return sorted({d.host_id for d in devices(backend)}) 274 275 276def host_count(backend: Optional[str] = None) -> int: 277 """Returns the number of hosts.""" 278 return len(host_ids(backend)) 279 280 281### utility functions 282 283@util.memoize 284def dtype_to_etype(dtype): 285 """Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64).""" 286 return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype)) 287 288 289@util.memoize 290def supported_numpy_dtypes(): 291 return {dtypes.canonicalize_dtype(dtype) 292 for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()} 293 294 295# TODO(mattjj,frostig): try to remove this function 296def normalize_to_xla_dtypes(val): 297 """Normalize dtypes in a value.""" 298 if hasattr(val, '__array__') or np.isscalar(val): 299 return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val))) 300 elif isinstance(val, (tuple, list)): 301 return tuple(normalize_to_xla_dtypes(x) for x in val) 302 raise TypeError('Can\'t convert to XLA: {}'.format(val)) 303 304def _numpy_array_constant(builder, value, canonicalize_types=True): 305 if canonicalize_types: 306 value = normalize_to_xla_dtypes(value) 307 return xops.ConstantLiteral(builder, value) 308 309def parameter(builder, num, shape, name=None, replicated=None): 310 if name is None: 311 name = '' 312 if replicated is None: 313 replicated = [] 314 elif isinstance(replicated, bool): 315 replicated = [replicated] * shape.leaf_count() 316 317 return xops.Parameter(builder, num, 318 shape.with_major_to_minor_layout_if_absent(), name, 319 replicated) 320 321 322def constant(builder, py_val, canonicalize_types=True): 323 """Translate constant `py_val` to a constant, canonicalizing its dtype. 324 325 Args: 326 py_val: a Python value to be translated to a constant. 327 328 Returns: 329 A representation of the constant, either a ComputationDataHandle or None 330 """ 331 py_type = type(py_val) 332 if py_type in _constant_handlers: 333 return _constant_handlers[py_type](builder, py_val, canonicalize_types) 334 else: 335 raise TypeError("No constant handler for type: {}".format(py_type)) 336 337# HLO instructions optionally can be annotated to say how the output should be 338# spatially partitioned (represented in XLA as OpSharding protos, see 339# _sharding_to_proto). For array outputs, the annotation is either an int per 340# dimension specifying the number of ways that dimension divided (i.e. the total 341# number of shards is the product), or None to indicate the array should be 342# replicated. Tuple outputs are represented as tuples thereof. XLA supports 343# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type 344# checkers don't support recursive types), so we only represent one level of 345# nesting in this type definition. 346SpatialSharding = Union[Tuple[int, ...], 347 None, 348 Tuple[Union[Tuple[int, ...], None], ...]] 349 350def _sharding_to_proto(sharding: SpatialSharding): 351 """Converts a SpatialSharding to an OpSharding. 352 353 See 354 https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601 355 for details on the OpSharding proto. 356 """ 357 proto = xla_client.OpSharding() 358 if isinstance(sharding, tuple) and not isinstance(sharding[0], int): 359 assert all(s is None or isinstance(s, tuple) for s in sharding) 360 return tuple_sharding_proto(list(map(_sharding_to_proto, sharding))) # type: ignore 361 362 if sharding is None: 363 proto.type = xla_client.OpSharding.Type.REPLICATED 364 else: 365 proto.type = xla_client.OpSharding.Type.OTHER 366 proto.tile_assignment_dimensions = list(sharding) 367 proto.tile_assignment_devices = list(range(np.product(sharding))) 368 return proto 369 370def tuple_sharding_proto(elems): 371 proto = xla_client.OpSharding() 372 assert all(isinstance(e, type(proto)) for e in elems) 373 proto.type = xla_client.OpSharding.Type.TUPLE 374 proto.tuple_shardings = elems 375 return proto 376 377def set_sharding_proto(builder, op, sharding_proto): 378 """Uses CustomCall to annotate a value as sharded.""" 379 # "Sharding" is a built-in custom call target that acts like an identity 380 # function, and is used to attach an OpSharding to. 381 return with_sharding_proto(builder, sharding_proto, xops.CustomCall, 382 builder, b"Sharding", [op], builder.get_shape(op)) 383 384def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): 385 """Builds op_fn(*args, **kwargs) with sharding annotation.""" 386 builder.set_sharding(sharding_proto) 387 try: 388 return op_fn(*args, **kwargs) 389 finally: 390 builder.clear_sharding() 391 392def set_sharding(builder, op, sharding: SpatialSharding): 393 """Uses CustomCall to annotate a value as sharded.""" 394 return set_sharding_proto(builder, op, _sharding_to_proto(sharding)) 395 396def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs): 397 """Builds op_fn(*args, **kwargs) with sharding annotation.""" 398 return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs) 399 400def make_computation_builder(name): 401 return xla_client.XlaBuilder(name) 402 403 404def register_constant_handler(type_, handler_fun): 405 _constant_handlers[type_] = handler_fun 406_constant_handlers: Dict[type, Callable] = {} 407 408 409def _ndarray_constant_handler(c, val, canonicalize_types=True): 410 """Constant handler for ndarray literals, handling zero-size strides. 411 412 This function essentially calls _numpy_array_constant(val) except it has 413 special handling of arrays with any strides of size zero: for those, it 414 generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose 415 to avoid staging in large literals that might arise from np.zeros or np.ones 416 or the output of lax.broadcast (which uses np.broadcast_to which in turn 417 uses size-zero strides). 418 419 Args: 420 c: an XlaBuilder 421 val: an ndarray. 422 423 Returns: 424 An XLA ComputationDataHandle / XlaOp representing the constant ndarray 425 staged into the XLA Computation. 426 """ 427 # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose 428 if dtypes.result_type(val) == dtypes.float0: 429 return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool)) 430 elif np.any(np.equal(0, val.strides)) and val.size > 0: 431 zero_stride_axes, = np.where(np.equal(0, val.strides)) 432 other_axes, = np.where(np.not_equal(0, val.strides)) 433 collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) 434 for ax in range(val.ndim))] 435 xla_val = xops.Broadcast( 436 _numpy_array_constant(c, collapsed_val, canonicalize_types), 437 np.take(val.shape, zero_stride_axes)) 438 permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes)) 439 return xops.Transpose(xla_val, permutation) 440 else: 441 return _numpy_array_constant(c, val, canonicalize_types) 442register_constant_handler(np.ndarray, _ndarray_constant_handler) 443 444 445def _scalar_constant_handler(c, val, canonicalize_types=True): 446 return _numpy_array_constant(c, val, canonicalize_types) 447 448for scalar_type in [np.int8, np.int16, np.int32, np.int64, 449 np.uint8, np.uint16, np.uint32, np.uint64, 450 np.float16, np.float32, np.float64, 451 np.bool_, np.longlong, 452 xla_client.bfloat16]: 453 register_constant_handler(scalar_type, _scalar_constant_handler) 454 455# https://github.com/winpython/winpython/issues/613#issuecomment-380121523 456if hasattr(np, "float128"): 457 register_constant_handler(np.float128, _scalar_constant_handler) 458 459def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True): 460 return _numpy_array_constant(c, dtype.type(val)) 461 462for ptype, dtype in dtypes.python_scalar_dtypes.items(): 463 register_constant_handler(ptype, partial(_python_scalar_handler, dtype)) 464