# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Interface and utility functions to XLA.

This module wraps the XLA client(s) and builders to standardize their interfaces
and provide some automatic type mapping logic for converting between Numpy and
XLA. There are also a handful of related casting utilities.
"""


from functools import partial, lru_cache
import os
import threading
from typing import Callable, Dict, List, Optional, Tuple, Union

from absl import logging
# Disable "WARNING: Logging before flag parsing goes to stderr." message
logging._warn_preinit_stderr = 0

import numpy as np

from ..config import flags
from jax._src import util
from .. import dtypes

try:
  from . import tpu_client
except ImportError:
  tpu_client = None
from . import xla_client

xops = xla_client.ops

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'jax_xla_backend', 'xla',
    'Default is "xla" for the XLA service directly, '
    'or "tpu_driver" for using high-performance access to Cloud TPU hardware.')
flags.DEFINE_string(
    'jax_backend_target', 'local',
    'Either "local" or "rpc:address" to connect to a remote service target.')
flags.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', ''),
    'Platform name for XLA. The default is to attempt to use a GPU if '
    'available, but fall back to CPU otherwise. To set the platform manually, '
    'pass "cpu" for CPU or "gpu" for GPU.')
flags.DEFINE_bool(
    'jax_disable_most_optimizations', False,
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')


def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
) -> xla_client.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional tuple of integers indicating the assignment of
      logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`.
    use_spmd_partitioning: boolean indicating whether to use SPMD (true) or
      MPMD (false) partitioning in XLA.
  """
  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  if device_assignment is not None:
    logging.vlog(
        2,
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if FLAGS.jax_disable_most_optimizations:
    debug_options = compile_options.executable_build_options.debug_options
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  return compile_options
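
# Example (illustrative sketch only): a caller compiling for four replicas on a
# single partition might build options as
#   options = get_compile_options(num_replicas=4, num_partitions=1,
#                                 device_assignment=(0, 1, 2, 3))
# and then hand them to the backend's compile step (the exact compile call
# lives in the caller, e.g. something like backend.compile(built, options)).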

_backends = {}

def register_backend(name, factory):
  _backends[name] = factory

def _get_local_backend(platform=None):
  if not platform:
    platform = FLAGS.jax_platform_name or None

  backend = xla_client.get_local_backend(platform)
  if backend is None:
    raise RuntimeError("No local XLA backends found.")

  if backend.platform == 'cpu' and platform != 'cpu':
    logging.warning('No GPU/TPU found, falling back to CPU. '
                    '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')

  return backend


register_backend('xla', _get_local_backend)

# memoize the TPU driver to be consistent with xla_client behavior
_tpu_backend = None

def _get_tpu_driver_backend(platform):
  del platform
  global _tpu_backend
  if _tpu_backend is None:
    backend_target = FLAGS.jax_backend_target
    if backend_target is None:
      raise ValueError('When using TPU Driver as the backend, you must specify '
                       '--jax_backend_target=<hostname>:8470.')
    _tpu_backend = tpu_client.TpuBackend.create(worker=backend_target)
  return _tpu_backend


if tpu_client:
  register_backend('tpu_driver', _get_tpu_driver_backend)


_backend_lock = threading.Lock()

@lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.
def get_backend(platform=None):
  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
  # 'backend' values are handled
  if not isinstance(platform, (type(None), str)):
    return platform

  with _backend_lock:
    backend = _backends.get(FLAGS.jax_xla_backend)
    if backend is None:
      msg = 'Unknown jax_xla_backend value "{}".'
      raise ValueError(msg.format(FLAGS.jax_xla_backend))
    return backend(platform)
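
# Example (sketch): resolving a backend by platform name. With default flags
# this returns the local XLA backend, falling back to CPU when no accelerator
# is present:
#   cpu_backend = get_backend('cpu')
#   default_backend = get_backend()   # honors --jax_platform_name if set
#   default_backend.platform          # e.g. 'cpu', 'gpu', or 'tpu'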


def get_device_backend(device=None):
  """Returns the Backend associated with `device`, or the default Backend."""
  platform = device.platform if device else None
  return get_backend(platform)


def device_count(backend: Optional[str] = None) -> int:
  """Returns the total number of devices.

  On most platforms, this is the same as :py:func:`jax.local_device_count`.
  However, on multi-host platforms, this will return the total number of devices
  across all hosts.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Number of devices.
  """
  return int(get_backend(backend).device_count())
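
# Example (sketch): on a single-host machine with one GPU,
#   device_count('gpu') == local_device_count('gpu') == 1
# whereas on a multi-host TPU setup device_count() counts devices on every
# host and local_device_count() counts only those attached to this host.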


def local_device_count(backend: Optional[str] = None) -> int:
  """Returns the number of devices on this host."""
  return int(get_backend(backend).local_device_count())


def devices(backend: Optional[str] = None) -> List[xla_client.Device]:
  """Returns a list of all devices for a given backend.

  Each device is represented by a subclass of :class:`Device` (e.g.
  :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
  equal to ``device_count(backend)``. Local devices can be identified by comparing
  :meth:`Device.host_id` to the value returned by :py:func:`jax.host_id`.

  If ``backend`` is ``None``, returns all the devices from the default backend.
  The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
  otherwise ``'cpu'``.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.
  """
  return get_backend(backend).devices()


def local_devices(host_id: Optional[int] = None,
                  backend: Optional[str] = None) -> List[xla_client.Device]:
  """Like :py:func:`jax.devices`, but only returns devices local to a given host.

  If ``host_id`` is ``None``, returns devices local to this host.

  Args:
    host_id: the integer ID of the host. Host IDs can be retrieved via
      :py:func:`jax.host_ids`.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.
  """
  if host_id is None:
    host_id = get_backend(backend).host_id()
  if host_id not in host_ids(backend):
    raise ValueError(f"Unknown host_id {host_id}")
  return [d for d in devices(backend) if d.host_id == host_id]
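
# Example (sketch):
#   devices()                   # all devices for the default backend
#   local_devices()             # only the devices attached to this host
#   devices('cpu')[0].platform  # 'cpu'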


def host_id(backend: Optional[str] = None) -> int:
  """Returns the integer host ID of this host.

  On most platforms, this will always be 0. On multi-host platforms, however,
  it varies from host to host.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Integer host ID.
  """
  return get_backend(backend).host_id()


def host_ids(backend: Optional[str] = None) -> List[int]:
  """Returns a sorted list of all host IDs."""
  return sorted({d.host_id for d in devices(backend)})


def host_count(backend: Optional[str] = None) -> int:
  """Returns the number of hosts."""
  return len(host_ids(backend))
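
# Example (sketch): on a single-host setup,
#   host_id() == 0, host_ids() == [0], host_count() == 1
# while on an N-host platform host_ids() typically enumerates 0..N-1 and
# host_count() == N.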


### utility functions

@util.memoize
def dtype_to_etype(dtype):
  """Convert from dtype to canonical etype (reading FLAGS.jax_enable_x64)."""
  return xla_client.dtype_to_etype(dtypes.canonicalize_dtype(dtype))
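
# Example (sketch): the result depends on the X64 flag. With jax_enable_x64 off
# (the default), np.float64 canonicalizes to float32, so
#   dtype_to_etype(np.float64) == dtype_to_etype(np.float32)
# With X64 enabled, float64 maps to XLA's F64 element type instead.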


@util.memoize
def supported_numpy_dtypes():
  return {dtypes.canonicalize_dtype(dtype)
          for dtype in xla_client.XLA_ELEMENT_TYPE_TO_DTYPE.values()}


# TODO(mattjj,frostig): try to remove this function
def normalize_to_xla_dtypes(val):
  """Normalize dtypes in a value."""
  if hasattr(val, '__array__') or np.isscalar(val):
    return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
  elif isinstance(val, (tuple, list)):
    return tuple(normalize_to_xla_dtypes(x) for x in val)
  raise TypeError("Can't convert to XLA: {}".format(val))
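
# Example (sketch, assuming X64 is disabled): Python scalars become canonical
# 32-bit ndarrays, and tuples/lists are normalized elementwise:
#   normalize_to_xla_dtypes(3)           # array(3, dtype=int32)
#   normalize_to_xla_dtypes((1.0, 2.0))  # (array(1., float32), array(2., float32))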

def _numpy_array_constant(builder, value, canonicalize_types=True):
  if canonicalize_types:
    value = normalize_to_xla_dtypes(value)
  return xops.ConstantLiteral(builder, value)

def parameter(builder, num, shape, name=None, replicated=None):
  if name is None:
    name = ''
  if replicated is None:
    replicated = []
  elif isinstance(replicated, bool):
    replicated = [replicated] * shape.leaf_count()

  return xops.Parameter(builder, num,
                        shape.with_major_to_minor_layout_if_absent(), name,
                        replicated)


def constant(builder, py_val, canonicalize_types=True):
  """Translate `py_val` to an XLA constant op, canonicalizing its dtype.

  Args:
    builder: an XlaBuilder to stage the constant into.
    py_val: a Python value to be translated to a constant.
    canonicalize_types: whether to canonicalize the value's dtype first.

  Returns:
    An XLA op representing the constant.
  """
  py_type = type(py_val)
  if py_type in _constant_handlers:
    return _constant_handlers[py_type](builder, py_val, canonicalize_types)
  else:
    raise TypeError("No constant handler for type: {}".format(py_type))

# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# _sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the
# total number of shards is the product), or None to indicate the array should
# be replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
                        None,
                        Tuple[Union[Tuple[int, ...], None], ...]]
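
# Example SpatialSharding values (sketch):
#   (2, 1)          - split an array output two ways along its first dimension
#   None            - replicate the output
#   ((2, 1), None)  - tuple output: shard the first element, replicate the second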

def _sharding_to_proto(sharding: SpatialSharding):
  """Converts a SpatialSharding to an OpSharding.

  See
  https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/xla_data.proto#L601
  for details on the OpSharding proto.
  """
  proto = xla_client.OpSharding()
  if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
    assert all(s is None or isinstance(s, tuple) for s in sharding)
    return tuple_sharding_proto(list(map(_sharding_to_proto, sharding)))  # type: ignore

  if sharding is None:
    proto.type = xla_client.OpSharding.Type.REPLICATED
  else:
    proto.type = xla_client.OpSharding.Type.OTHER
    proto.tile_assignment_dimensions = list(sharding)
    proto.tile_assignment_devices = list(range(np.prod(sharding)))
  return proto

def tuple_sharding_proto(elems):
  proto = xla_client.OpSharding()
  assert all(isinstance(e, type(proto)) for e in elems)
  proto.type = xla_client.OpSharding.Type.TUPLE
  proto.tuple_shardings = elems
  return proto

def set_sharding_proto(builder, op, sharding_proto):
  """Uses CustomCall to annotate a value as sharded."""
  # "Sharding" is a built-in custom call target that acts like an identity
  # function and exists only to carry the OpSharding annotation attached here.
  return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
                             builder, b"Sharding", [op], builder.get_shape(op))

def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
  """Builds op_fn(*args, **kwargs) with sharding annotation."""
  builder.set_sharding(sharding_proto)
  try:
    return op_fn(*args, **kwargs)
  finally:
    builder.clear_sharding()

def set_sharding(builder, op, sharding: SpatialSharding):
  """Uses CustomCall to annotate a value as sharded."""
  return set_sharding_proto(builder, op, _sharding_to_proto(sharding))

def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
  """Builds op_fn(*args, **kwargs) with sharding annotation."""
  return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
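
# Example (sketch): annotating an op so its output is split two ways along its
# leading dimension (xops.Add is just an illustrative op here):
#   c = make_computation_builder('sharded_add')   # defined below
#   shape = xla_client.shape_from_pyval(np.zeros((8, 4), np.float32))
#   x, y = parameter(c, 0, shape), parameter(c, 1, shape)
#   out = with_sharding(c, (2, 1), xops.Add, x, y)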

def make_computation_builder(name):
  return xla_client.XlaBuilder(name)


_constant_handlers: Dict[type, Callable] = {}

def register_constant_handler(type_, handler_fun):
  _constant_handlers[type_] = handler_fun


def _ndarray_constant_handler(c, val, canonicalize_types=True):
  """Constant handler for ndarray literals, handling zero-size strides.

  This function essentially calls _numpy_array_constant(val) except it has
  special handling of arrays with any strides of size zero: for those, it
  generates appropriate calls to _numpy_array_constant, Broadcast, and Transpose
  to avoid staging in large literals that might arise from np.zeros or np.ones
  or the output of lax.broadcast (which uses np.broadcast_to, which in turn
  uses size-zero strides).

  Args:
    c: an XlaBuilder.
    val: an ndarray.

  Returns:
    An XlaOp representing the constant ndarray staged into the XLA Computation.
  """
  # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
  if dtypes.result_type(val) == dtypes.float0:
    return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
  elif np.any(np.equal(0, val.strides)) and val.size > 0:
    zero_stride_axes, = np.where(np.equal(0, val.strides))
    other_axes, = np.where(np.not_equal(0, val.strides))
    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                              for ax in range(val.ndim))]
    xla_val = xops.Broadcast(
        _numpy_array_constant(c, collapsed_val, canonicalize_types),
        np.take(val.shape, zero_stride_axes))
    permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
    return xops.Transpose(xla_val, permutation)
  else:
    return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
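
# Example (sketch) of why the zero-stride path matters: np.broadcast_to returns
# views whose broadcast dimensions have stride 0, so a constant such as
#   big = np.broadcast_to(np.arange(3, dtype=np.float32), (1000, 3))
# is staged as the 3-element literal plus Broadcast/Transpose ops rather than
# as a full 1000x3 literal embedded in the HLO.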


def _scalar_constant_handler(c, val, canonicalize_types=True):
  return _numpy_array_constant(c, val, canonicalize_types)

for scalar_type in [np.int8, np.int16, np.int32, np.int64,
                    np.uint8, np.uint16, np.uint32, np.uint64,
                    np.float16, np.float32, np.float64,
                    np.bool_, np.longlong,
                    xla_client.bfloat16]:
  register_constant_handler(scalar_type, _scalar_constant_handler)

# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
if hasattr(np, "float128"):
  register_constant_handler(np.float128, _scalar_constant_handler)

def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
  return _numpy_array_constant(c, dtype.type(val))

for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))