1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# Array type functions.
16#
17# JAX dtypes differ from NumPy in both:
18# a) their type promotion rules, and
19# b) the set of supported types (e.g., bfloat16),
20# so we need our own implementation that deviates from NumPy in places.
21
22
23from distutils.util import strtobool
24import functools
25import os
26from typing import Dict
27
28import numpy as np
29
30from ._src import util
31from .config import flags
32from .lib import xla_client
33
34from ._src import traceback_util
# Register this file with traceback_util's exclusion list (presumably so its
# frames are hidden from user-facing JAX tracebacks).
traceback_util.register_exclusion(__file__)

FLAGS = flags.FLAGS
# `jax_enable_x64` decides whether 64-bit dtypes survive canonicalize_dtype.
# Its default is read from the JAX_ENABLE_X64 environment variable (parsed with
# strtobool; 'False' when unset).
# NOTE(review): distutils (source of strtobool) is deprecated and removed in
# Python 3.12 — consider a local replacement.
flags.DEFINE_bool('jax_enable_x64',
                  strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
                  'Enable 64-bit types to be used.')
41
# bfloat16 support: the scalar type is provided by xla_client. Its NumPy dtype
# is cached so the rest of this module can compare against it directly.
bfloat16 = xla_client.bfloat16
_bfloat16_dtype = np.dtype(bfloat16)
45
# Default types.
# NumPy scalar types that Python scalars map to by default (see
# python_scalar_dtypes). canonicalize_dtype narrows the 64-bit ones to 32 bits
# when jax_enable_x64 is off.
bool_, int_, float_, complex_ = np.bool_, np.int64, np.float64, np.complex128

# TODO(phawkins): change the above defaults to:
# int_ = np.int32
# float_ = np.float32
# complex_ = np.complex64

# Trivial vector-space dtype needed for tangent values of int/bool primals: a
# structured dtype whose single field, 'float0', is a zero-length void.
float0 = np.dtype([('float0', np.void, 0)])
60
# Maps each 64-bit dtype to its 32-bit counterpart. canonicalize_dtype consults
# this table when jax_enable_x64 is disabled; other dtypes pass through.
_dtype_to_32bit_dtype = {
    np.dtype(wide): np.dtype(narrow)
    for wide, narrow in (('int64', 'int32'),
                         ('uint64', 'uint32'),
                         ('float64', 'float32'),
                         ('complex128', 'complex64'))
}
67
@util.memoize
def canonicalize_dtype(dtype):
  """Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64.

  Accepts anything ``np.dtype`` understands, plus the name ``"bfloat16"``.
  With 64-bit mode disabled, 64-bit dtypes are replaced by their 32-bit
  counterparts; every other dtype is returned as-is.

  Raises:
    TypeError: if ``dtype`` cannot be converted to a NumPy dtype.
  """
  names_bfloat16 = isinstance(dtype, str) and dtype == "bfloat16"
  requested = bfloat16 if names_bfloat16 else dtype
  try:
    canonical = np.dtype(requested)
  except TypeError as e:
    raise TypeError(f'dtype {requested!r} not understood') from e
  if not FLAGS.jax_enable_x64:
    canonical = _dtype_to_32bit_dtype.get(canonical, canonical)
  return canonical
82
83
# Default dtypes corresponding to Python scalars. The float0 dtype is also
# registered, mapping to itself.
python_scalar_dtypes = {
  pytype: np.dtype(default)
  for pytype, default in ((bool, bool_), (int, int_), (float, float_),
                          (complex, complex_))
}
python_scalar_dtypes[float0] = float0
92
def scalar_type_of(x):
  """Return the Python scalar type (bool, int, float or complex) matching ``x``.

  bfloat16 values map to ``float``. They are checked explicitly, mirroring the
  special case in this module's ``issubdtype``, because ``np.issubdtype`` may
  not recognize the custom bfloat16 type as a subtype of ``np.floating``.

  Raises:
    TypeError: if the dtype of ``x`` is not boolean, integer, floating-point or
      complex.
  """
  typ = dtype(x)
  if typ == _bfloat16_dtype:
    return float
  elif np.issubdtype(typ, np.bool_):
    return bool
  elif np.issubdtype(typ, np.integer):
    return int
  elif np.issubdtype(typ, np.floating):
    return float
  elif np.issubdtype(typ, np.complexfloating):
    return complex
  else:
    raise TypeError("Invalid scalar value {}".format(x))
105
def coerce_to_array(x):
  """Coerces a scalar or NumPy array to an np.array.

  Handles Python scalar type promotion according to JAX's rules, not NumPy's
  rules: Python scalars whose type is a key of ``python_scalar_dtypes`` are
  converted with that default dtype; anything else lets ``np.array`` infer it.
  """
  # Pass the looked-up dtype (or None, meaning "infer") straight through instead
  # of testing it for truthiness: some NumPy versions evaluate field-less dtype
  # objects as falsy (len() is the number of fields), which would silently
  # discard the JAX default dtype.
  scalar_dtype = python_scalar_dtypes.get(type(x))
  return np.array(x, dtype=scalar_dtype)
114
# NumPy's iinfo is used directly; only finfo (below) needs a bfloat16 special
# case.
iinfo = np.iinfo
116
class finfo(np.finfo):
  __doc__ = np.finfo.__doc__
  # Hand-built finfo objects keyed by dtype; __new__ only stores bfloat16 here.
  _finfo_cache: Dict[np.dtype, np.finfo] = {}

  @staticmethod
  def _bfloat16_finfo():
    """Build an ``np.finfo`` instance describing bfloat16.

    bfloat16 has 8 exponent bits (the same exponent range as float32) and 7
    stored mantissa bits. The instance is allocated with ``object.__new__``,
    bypassing NumPy's own initialization, so every attribute is set by hand.
    """
    def float_to_str(f):
      return "%12.4e" % float(f)

    bf16 = _bfloat16_dtype.type
    tiny = float.fromhex("0x1p-126")
    resolution = 0.01
    eps = float.fromhex("0x1p-7")
    epsneg = float.fromhex("0x1p-8")
    max_ = float.fromhex("0x1.FEp127")

    obj = object.__new__(np.finfo)
    obj.dtype = _bfloat16_dtype
    obj.bits = 16
    obj.eps = bf16(eps)
    obj.epsneg = bf16(epsneg)
    obj.machep = -7
    obj.negep = -8
    obj.max = bf16(max_)
    obj.min = bf16(-max_)
    obj.nexp = 8
    obj.nmant = 7
    obj.iexp = obj.nexp
    # Same exponent range as float32. NumPy's finfo.__str__ (at least in NumPy
    # 1.x) formats minexp/maxexp, so str(finfo(bfloat16)) fails without them.
    obj.minexp = -126
    obj.maxexp = 128
    obj.precision = 2
    obj.resolution = bf16(resolution)
    obj.tiny = bf16(tiny)
    obj.machar = None  # np.core.getlimits.MachArLike does not support bfloat16.

    # Pre-formatted strings read by np.finfo's string representation.
    obj._str_tiny = float_to_str(tiny)
    obj._str_max = float_to_str(max_)
    obj._str_epsneg = float_to_str(epsneg)
    obj._str_eps = float_to_str(eps)
    obj._str_resolution = float_to_str(resolution)
    return obj

  def __new__(cls, dtype):
    """Return finfo for ``dtype``; bfloat16 info is synthesized once and cached."""
    if (isinstance(dtype, str) and dtype == 'bfloat16') or dtype == _bfloat16_dtype:
      if _bfloat16_dtype not in cls._finfo_cache:
        cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
      return cls._finfo_cache[_bfloat16_dtype]
    return super().__new__(cls, dtype)
162
def _issubclass(a, b):
  """Return ``issubclass(a, b)``, mapping a TypeError to False.

  The builtin raises TypeError when ``a`` is not a class (e.g. a dtype object or
  a string); this wrapper reports False in that case instead.
  """
  try:
    result = issubclass(a, b)
  except TypeError:
    result = False
  return result
173
def issubdtype(a, b):
  """Returns True if ``a`` is a sub-dtype of ``b``, with bfloat16 support.

  Behaves like :func:`numpy.issubdtype`, except that bfloat16 is treated as a
  floating-point type, the name ``"bfloat16"`` is accepted for either argument,
  and JAX scalar types are accepted as ``b``.
  """
  # Accept the "bfloat16" name, as canonicalize_dtype and finfo do; otherwise it
  # would fall through to NumPy, which presumably cannot resolve it by name.
  if isinstance(a, str) and a == 'bfloat16':
    a = bfloat16
  if isinstance(b, str) and b == 'bfloat16':
    b = bfloat16
  if a == bfloat16:
    if isinstance(b, np.dtype):
      return b == _bfloat16_dtype
    else:
      # bfloat16 is classified as floating point. np.issubdtype(t, np.generic)
      # holds for every NumPy scalar type, so bfloat16 reports True there too.
      return b in [bfloat16, np.floating, np.inexact, np.number, np.generic]
  if not _issubclass(b, np.generic):
    # Workaround for JAX scalar types. NumPy's issubdtype has a backward
    # compatibility behavior for the second argument of issubdtype that
    # interacts badly with JAX's custom scalar types. As a workaround,
    # explicitly cast the second argument to a NumPy type object.
    b = np.dtype(b).type
  return np.issubdtype(a, b)
187
# NumPy helpers re-exported unchanged.
# NOTE(review): np.issubsctype is deprecated in recent NumPy and removed in
# NumPy 2.0 — confirm the supported NumPy range before relying on it.
can_cast = np.can_cast
issubsctype = np.issubsctype
190
def dtype_real(typ):
  """Return the type holding the real part of the input type.

  complex64 maps to float32 and complex128 to float64. Non-complex inputs are
  returned unchanged (the very same object).

  Raises:
    TypeError: for any other complex floating type.
  """
  if not np.issubdtype(typ, np.complexfloating):
    return typ
  known_pairs = ((np.dtype('complex64'), np.dtype('float32')),
                 (np.dtype('complex128'), np.dtype('float64')))
  for complex_dtype, real_dtype in known_pairs:
    if typ == complex_dtype:
      return real_dtype
  raise TypeError("Unknown complex floating type {}".format(typ))
202
# Enumeration of all valid JAX types in order.
# _weak_types are the Python scalar types treated as weakly typed; they are
# appended after the concrete dtypes. _type_promotion_lattice unpacks
# _jax_types positionally, so reordering or extending this list requires
# updating that function as well.
_weak_types = [int, float, complex]
_jax_types = [
  np.dtype('bool'),
  np.dtype('uint8'),
  np.dtype('uint16'),
  np.dtype('uint32'),
  np.dtype('uint64'),
  np.dtype('int8'),
  np.dtype('int16'),
  np.dtype('int32'),
  np.dtype('int64'),
  np.dtype(bfloat16),
  np.dtype('float16'),
  np.dtype('float32'),
  np.dtype('float64'),
  np.dtype('complex64'),
  np.dtype('complex128'),
] + _weak_types
222
def _jax_type(value):
  """Return the promotion-lattice node (jax type) for a value or type.

  The Python types in ``_weak_types`` are returned as-is. A weakly typed value
  maps to the Python scalar type of its dtype when that type is weak; anything
  else maps to its dtype.
  """
  # Identity, not membership: `value in _weak_types` can give false positives
  # because dtype objects overload ==.
  for weak in _weak_types:
    if value is weak:
      return value
  value_dtype = dtype(value)
  if not is_weakly_typed(value):
    return value_dtype
  pytype = type(value_dtype.type(0).item())
  return pytype if pytype in _weak_types else value_dtype
234
def _type_promotion_lattice():
  """Return the type promotion lattice in the form of a DAG.

  The result maps each type to the list of types immediately above it on the
  lattice. Local names abbreviate the entries of ``_jax_types`` (unpacked
  positionally): b=bool, u=unsigned int, i=signed int, bf=bfloat16, f=float,
  c=complex. The digit is the width in bytes, per component for complex
  (c4 is complex64, c8 is complex128). The trailing-underscore names i_, f_, c_
  are the weak Python types int, float and complex.
  """
  b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8, i_, f_, c_ = _jax_types
  return {
    # bool sits below the weak Python int.
    b1: [i_],
    # Unsigned ints promote to the next wider signed and unsigned types.
    u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
    # Weak int sits below the narrowest signed and unsigned ints.
    i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
    # Weak float sits below every float type and the weak complex type.
    f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
    c_: [c4], c4: [c8], c8: [],
  }
248
def _make_lattice_upper_bounds():
  """Compute, for every lattice node, the set of all nodes reachable from it.

  Each node's set includes the node itself and is grown to a fixed point by
  repeatedly following the lattice edges.

  Raises:
    ValueError: if a node can reach itself, i.e. the lattice has a cycle.
  """
  lattice = _type_promotion_lattice()
  upper_bounds = {}
  for node in lattice:
    reached = {node}
    while True:
      frontier = set()
      for member in reached:
        frontier.update(lattice[member])
      if node in frontier:
        raise ValueError(f"cycle detected in type promotion lattice for node {node}")
      if frontier <= reached:
        break
      reached |= frontier
    upper_bounds[node] = reached
  return upper_bounds
# Upper-bound sets for every lattice node, computed once at import time.
_lattice_upper_bounds = _make_lattice_upper_bounds()
262
@functools.lru_cache(512)  # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(*nodes):
  """Compute the least upper bound of ``nodes`` in the type promotion lattice.

  For a node n of the partially ordered set S, let UB(n) = {m ∈ S | n ≤ m} be
  its upper bounds (precomputed in ``_lattice_upper_bounds``). For N ⊆ S the
  common upper bounds are CUB(N) = {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}, and the least
  upper bound is the c ∈ CUB(N) with c ≤ d for all d ∈ CUB(N). Since c ≤ d
  exactly when d ∈ UB(c), this reads

    LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}.

  Shortcut: every c ∈ N already satisfies CUB(N) ⊆ UB(c), so when N ∩ CUB(N)
  is non-empty it follows that LUB(N) = N ∩ CUB(N), and the scan over CUB(N)
  is skipped.

  Raises:
    ValueError: if the nodes do not have exactly one least upper bound.
  """
  node_set = set(nodes)
  upper = _lattice_upper_bounds
  common = set.intersection(*[upper[node] for node in node_set])
  least = common & node_set
  if not least:
    least = {c for c in common if common.issubset(upper[c])}
  if len(least) != 1:
    raise ValueError(f"{nodes} do not have a unique least upper bound.")
  return least.pop()
290
def promote_types(a, b):
  """Returns the type to which a binary operation should cast its arguments.

  For details of JAX's type promotion semantics, see :ref:`type-promotion`.

  Args:
    a: a :class:`numpy.dtype` or a dtype specifier.
    b: a :class:`numpy.dtype` or a dtype specifier.

  Returns:
    A :class:`numpy.dtype` object.
  """
  def as_lattice_node(t):
    # Weak Python scalar types are lattice nodes themselves; everything else
    # becomes a NumPy dtype. Identity is used because dtypes overload ==.
    if any(t is weak for weak in _weak_types):
      return t
    return np.dtype(t)

  return np.dtype(_least_upper_bound(as_lattice_node(a), as_lattice_node(b)))
306
def is_weakly_typed(x):
  """Return whether ``x`` is weakly typed.

  Values carrying an ``aval`` report its ``weak_type`` attribute. Any value
  without one is weak exactly when its type is listed in ``_weak_types``.
  """
  try:
    weak = x.aval.weak_type
  except AttributeError:
    weak = type(x) in _weak_types
  return weak
312
def is_python_scalar(x):
  """Return whether ``x`` behaves like a Python scalar.

  Values carrying an ``aval`` qualify when it is weakly typed and ``x`` is
  zero-dimensional. Other values qualify when their type is a key of
  ``python_scalar_dtypes``.
  """
  try:
    weak = x.aval.weak_type
    result = weak and np.ndim(x) == 0
  except AttributeError:
    result = type(x) in python_scalar_dtypes
  return result
318
def dtype(x):
  """Return the NumPy dtype for a value or type ``x``.

  Python scalars use JAX's defaults from ``python_scalar_dtypes``. Everything
  else is resolved with ``np.result_type``.
  """
  scalar_default = python_scalar_dtypes.get(type(x))
  if scalar_default is not None:
    return scalar_default
  return np.result_type(x)
323
def result_type(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  A single argument is converted with :func:`dtype`. Two or more arguments are
  promoted through JAX's type promotion lattice. Either way the result is
  passed through :func:`canonicalize_dtype`.
  """
  # TODO(jakevdp): propagate weak_type to the result when necessary.
  if len(args) >= 2:
    return canonicalize_dtype(_least_upper_bound(*{_jax_type(arg) for arg in args}))
  return canonicalize_dtype(dtype(args[0]))
331