# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Array type functions.
#
# JAX dtypes differ from NumPy in both:
#   a) their type promotion rules, and
#   b) the set of supported types (e.g., bfloat16),
# so we need our own implementation that deviates from NumPy in places.


import functools
import os
from typing import Dict

import numpy as np

from ._src import util
from .config import flags
from .lib import xla_client

from ._src import traceback_util
traceback_util.register_exclusion(__file__)


def strtobool(val):
  """Convert a string representation of truth to 1 (true) or 0 (false).

  Drop-in replacement for ``distutils.util.strtobool``: ``distutils`` is
  deprecated (PEP 632) and no longer ships with Python 3.12+, so importing it
  here would make this module fail to load.

  True values are 'y', 'yes', 't', 'true', 'on' and '1'; false values are
  'n', 'no', 'f', 'false', 'off' and '0'. Matching is case-insensitive.

  Args:
    val: the string to interpret.

  Returns:
    ``1`` for a true value, ``0`` for a false value (ints, as distutils did).

  Raises:
    ValueError: if ``val`` is not one of the recognised spellings.
  """
  val = val.lower()
  if val in ('y', 'yes', 't', 'true', 'on', '1'):
    return 1
  elif val in ('n', 'no', 'f', 'false', 'off', '0'):
    return 0
  else:
    raise ValueError("invalid truth value %r" % (val,))


FLAGS = flags.FLAGS
# The default for the flag is read from the JAX_ENABLE_X64 environment
# variable, falling back to 'False' when it is unset.
flags.DEFINE_bool('jax_enable_x64',
                  strtobool(os.getenv('JAX_ENABLE_X64', 'False')),
                  'Enable 64-bit types to be used.')

# bfloat16 support: the scalar type comes from XLA, not from NumPy itself.
bfloat16 = xla_client.bfloat16
_bfloat16_dtype = np.dtype(bfloat16)

# Default types.
bool_ = np.bool_
int_ = np.int64
float_ = np.float64
complex_ = np.complex128

# TODO(phawkins): change the above defaults to:
# int_ = np.int32
# float_ = np.float32
# complex_ = np.complex64

# Trivial vectorspace datatype needed for tangent values of int/bool primals
float0 = np.dtype([('float0', np.void, 0)])

# Narrowing map applied by canonicalize_dtype when jax_enable_x64 is off.
_dtype_to_32bit_dtype = {
    np.dtype('int64'): np.dtype('int32'),
    np.dtype('uint64'): np.dtype('uint32'),
    np.dtype('float64'): np.dtype('float32'),
    np.dtype('complex128'): np.dtype('complex64'),
}


@util.memoize
def canonicalize_dtype(dtype):
  """Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64.

  Args:
    dtype: anything accepted by ``np.dtype``, or the string ``"bfloat16"``.

  Returns:
    A :class:`numpy.dtype`. With ``jax_enable_x64`` enabled it is returned
    unchanged; otherwise 64-bit types are narrowed via
    ``_dtype_to_32bit_dtype`` and all other types pass through.

  Raises:
    TypeError: if ``np.dtype`` cannot interpret ``dtype``.
  """
  # Map the string name onto the XLA scalar type before asking NumPy.
  if isinstance(dtype, str) and dtype == "bfloat16":
    dtype = bfloat16
  try:
    dtype = np.dtype(dtype)
  except TypeError as e:
    raise TypeError(f'dtype {dtype!r} not understood') from e

  if FLAGS.jax_enable_x64:
    return dtype
  else:
    return _dtype_to_32bit_dtype.get(dtype, dtype)


# Default dtypes corresponding to Python scalars.
python_scalar_dtypes = {
    bool: np.dtype(bool_),
    int: np.dtype(int_),
    float: np.dtype(float_),
    complex: np.dtype(complex_),
    float0: float0
}


def scalar_type_of(x):
  """Return the Python scalar type (bool, int, float or complex) for x's dtype.

  Raises:
    TypeError: if the dtype of ``x`` is not boolean, integer, floating point
      or complex.
  """
  typ = dtype(x)
  # NumPy's type hierarchy does not place XLA's bfloat16 under np.floating
  # (which is also why ``issubdtype`` below special-cases it), so classify it
  # explicitly rather than letting it fall through to the TypeError.
  if typ == _bfloat16_dtype:
    return float
  elif np.issubdtype(typ, np.bool_):
    return bool
  elif np.issubdtype(typ, np.integer):
    return int
  elif np.issubdtype(typ, np.floating):
    return float
  elif np.issubdtype(typ, np.complexfloating):
    return complex
  else:
    raise TypeError("Invalid scalar value {}".format(x))


def coerce_to_array(x):
  """Coerces a scalar or NumPy array to an np.array.

  Handles Python scalar type promotion according to JAX's rules, not NumPy's
  rules.
  """
  # ``None`` means "not a Python scalar; let NumPy infer the dtype".
  # Do NOT test the dtype for truthiness: a non-structured np.dtype has
  # len() == 0 and is therefore falsy, which would silently skip JAX's rule.
  dtype = python_scalar_dtypes.get(type(x), None)
  return np.array(x, dtype)


iinfo = np.iinfo


class finfo(np.finfo):
  __doc__ = np.finfo.__doc__
  # Hand-built finfo objects for dtypes NumPy cannot describe (bfloat16),
  # keyed by dtype so each one is only constructed once.
  _finfo_cache: Dict[np.dtype, np.finfo] = {}

  @staticmethod
  def _bfloat16_finfo():
    """Build a ``np.finfo`` instance describing bfloat16 field by field.

    bfloat16 has 8 exponent bits and 7 explicit mantissa bits, so the values
    below are written out directly instead of being derived by NumPy.
    """
    def float_to_str(f):
      return "%12.4e" % float(f)

    bfloat16 = _bfloat16_dtype.type
    tiny = float.fromhex("0x1p-126")
    resolution = 0.01
    eps = float.fromhex("0x1p-7")
    epsneg = float.fromhex("0x1p-8")
    max_value = float.fromhex("0x1.FEp127")

    # Bypass np.finfo.__new__, which would try to introspect the dtype.
    obj = object.__new__(np.finfo)
    obj.dtype = _bfloat16_dtype
    obj.bits = 16
    obj.eps = bfloat16(eps)
    obj.epsneg = bfloat16(epsneg)
    obj.machep = -7
    obj.negep = -8
    obj.max = bfloat16(max_value)
    obj.min = bfloat16(-max_value)
    obj.nexp = 8
    obj.nmant = 7
    obj.iexp = obj.nexp
    obj.precision = 2
    obj.resolution = bfloat16(resolution)
    obj.tiny = bfloat16(tiny)
    obj.machar = None  # np.core.getlimits.MachArLike does not support bfloat16.

    # Pre-formatted strings used by np.finfo's __str__/__repr__.
    obj._str_tiny = float_to_str(tiny)
    obj._str_max = float_to_str(max_value)
    obj._str_epsneg = float_to_str(epsneg)
    obj._str_eps = float_to_str(eps)
    obj._str_resolution = float_to_str(resolution)
    return obj

  def __new__(cls, dtype):
    # bfloat16 (by name or by dtype) gets the cached hand-built object;
    # everything else is delegated to NumPy.
    if (isinstance(dtype, str) and dtype == 'bfloat16') or dtype == _bfloat16_dtype:
      if _bfloat16_dtype not in cls._finfo_cache:
        cls._finfo_cache[_bfloat16_dtype] = cls._bfloat16_finfo()
      return cls._finfo_cache[_bfloat16_dtype]
    return super().__new__(cls, dtype)


def _issubclass(a, b):
  """Determines if ``a`` is a subclass of ``b``.

  Similar to issubclass, but returns False instead of an exception if `a` is not
  a class.
  """
  try:
    return issubclass(a, b)
  except TypeError:
    return False


def issubdtype(a, b):
  """Like ``np.issubdtype``, but also understands XLA's bfloat16.

  When ``a`` is bfloat16 it is treated as a floating-point type: if ``b`` is an
  ``np.dtype`` the result is whether ``b`` is bfloat16; otherwise it is whether
  ``b`` is bfloat16, ``np.floating``, ``np.inexact``, ``np.number`` or
  ``np.generic`` (NumPy reports every scalar type as a subtype of np.generic).
  All other cases are delegated to ``np.issubdtype``.
  """
  if a == bfloat16:
    if isinstance(b, np.dtype):
      return b == _bfloat16_dtype
    else:
      return b in [bfloat16, np.floating, np.inexact, np.number, np.generic]
  if not _issubclass(b, np.generic):
    # Workaround for JAX scalar types. NumPy's issubdtype has a backward
    # compatibility behavior for the second argument of issubdtype that
    # interacts badly with JAX's custom scalar types. As a workaround,
    # explicitly cast the second argument to a NumPy type object.
    b = np.dtype(b).type
  return np.issubdtype(a, b)


can_cast = np.can_cast
issubsctype = np.issubsctype


def dtype_real(typ):
  """Return the type holding the real part of the input type.

  complex64 maps to float32 and complex128 to float64; non-complex types are
  returned unchanged.

  Raises:
    TypeError: for any other complex floating type.
  """
  if np.issubdtype(typ, np.complexfloating):
    if typ == np.dtype('complex64'):
      return np.dtype('float32')
    elif typ == np.dtype('complex128'):
      return np.dtype('float64')
    else:
      raise TypeError("Unknown complex floating type {}".format(typ))
  else:
    return typ


# Enumeration of all valid JAX types in order.
_weak_types = [int, float, complex]
_jax_types = [
    np.dtype('bool'),
    np.dtype('uint8'),
    np.dtype('uint16'),
    np.dtype('uint32'),
    np.dtype('uint64'),
    np.dtype('int8'),
    np.dtype('int16'),
    np.dtype('int32'),
    np.dtype('int64'),
    np.dtype(bfloat16),
    np.dtype('float16'),
    np.dtype('float32'),
    np.dtype('float64'),
    np.dtype('complex64'),
    np.dtype('complex128'),
] + _weak_types


def _jax_type(value):
  """Return the jax type for a value or type."""
  # Note: `x in _weak_types` can return false positives due to dtype comparator overloading.
  if any(value is typ for typ in _weak_types):
    return value
  dtype_ = dtype(value)
  if is_weakly_typed(value):
    # Map a weakly-typed value back to the Python type of its scalar form.
    pytype = type(dtype_.type(0).item())
    if pytype in _weak_types:
      return pytype
  return dtype_


def _type_promotion_lattice():
  """
  Return the type promotion lattice in the form of a DAG.
  This DAG maps each type to its immediately higher type on the lattice.
  """
  b1, u1, u2, u4, u8, i1, i2, i4, i8, bf, f2, f4, f8, c4, c8, i_, f_, c_ = _jax_types
  return {
      b1: [i_],
      u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
      i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
      f_: [bf, f2, c_], bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
      c_: [c4], c4: [c8], c8: [],
  }


def _make_lattice_upper_bounds():
  """Map each lattice node to the set of nodes reachable from it (incl. itself).

  The set for each node is grown to a fixed point by repeatedly adding the
  direct successors of every node already in it.

  Raises:
    ValueError: if a node can reach itself, i.e. the lattice has a cycle.
  """
  lattice = _type_promotion_lattice()
  upper_bounds = {node: {node} for node in lattice}
  for n in lattice:
    while True:
      new_upper_bounds = set().union(*(lattice[b] for b in upper_bounds[n]))
      if n in new_upper_bounds:
        raise ValueError(f"cycle detected in type promotion lattice for node {n}")
      if new_upper_bounds.issubset(upper_bounds[n]):
        break
      upper_bounds[n] |= new_upper_bounds
  return upper_bounds
_lattice_upper_bounds = _make_lattice_upper_bounds()


@functools.lru_cache(512)  # don't use util.memoize because there is no X64 dependence.
def _least_upper_bound(*nodes):
  # This function computes the least upper bound of a set of nodes N within a partially
  # ordered set defined by the lattice generated above.
  # Given a partially ordered set S, let the set of upper bounds of n ∈ S be
  #   UB(n) ≡ {m ∈ S | n ≤ m}
  # Further, for a set of nodes N ⊆ S, let the set of common upper bounds be given by
  #   CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}
  # Then the least upper bound of N is defined as
  #   LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d}
  # The definition of an upper bound implies that c ≤ d if and only if d ∈ UB(c),
  # so the LUB can be expressed:
  #   LUB(N) = {c ∈ CUB(N) | ∀ d ∈ CUB(N): d ∈ UB(c)}
  # or, equivalently:
  #   LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}
  # By definition, LUB(N) has a cardinality of 1 for a partially ordered set.
  # Note a potential algorithmic shortcut: from the definition of CUB(N), we have
  #   ∀ c ∈ N: CUB(N) ⊆ UB(c)
  # So if N ∩ CUB(N) is nonempty, it follows that LUB(N) = N ∩ CUB(N).
  N = set(nodes)
  UB = _lattice_upper_bounds
  CUB = set.intersection(*(UB[n] for n in N))
  LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}
  if len(LUB) == 1:
    return LUB.pop()
  else:
    raise ValueError(f"{nodes} do not have a unique least upper bound.")


def promote_types(a, b):
  """Returns the type to which a binary operation should cast its arguments.

  For details of JAX's type promotion semantics, see :ref:`type-promotion`.

  Args:
    a: a :class:`numpy.dtype` or a dtype specifier.
    b: a :class:`numpy.dtype` or a dtype specifier.

  Returns:
    A :class:`numpy.dtype` object.
  """
  # The Python types int/float/complex are kept as-is (weak types); anything
  # else is normalised to an np.dtype so it matches the lattice keys.
  a = a if any(a is t for t in _weak_types) else np.dtype(a)
  b = b if any(b is t for t in _weak_types) else np.dtype(b)
  return np.dtype(_least_upper_bound(a, b))


def is_weakly_typed(x):
  """Return x.aval.weak_type if x has an aval; else whether x is a bare int/float/complex."""
  try:
    return x.aval.weak_type
  except AttributeError:
    return type(x) in _weak_types


def is_python_scalar(x):
  """Return True if x is, or behaves like, a Python scalar.

  Values with an aval qualify when they are weakly typed and rank 0; other
  values qualify when their type is a key of ``python_scalar_dtypes``.
  """
  try:
    return x.aval.weak_type and np.ndim(x) == 0
  except AttributeError:
    return type(x) in python_scalar_dtypes


def dtype(x):
  """Return the dtype of x, using JAX's defaults for Python scalars."""
  if type(x) in python_scalar_dtypes:
    return python_scalar_dtypes[type(x)]
  return np.result_type(x)


def result_type(*args):
  """Convenience function to apply Numpy argument dtype promotion.

  Raises:
    ValueError: if called with no arguments (matching ``np.result_type``).
  """
  if not args:
    raise ValueError("at least one array or dtype is required")
  # TODO(jakevdp): propagate weak_type to the result.
  if len(args) == 1:
    return canonicalize_dtype(dtype(args[0]))
  # TODO(jakevdp): propagate weak_type to the result when necessary.
  return canonicalize_dtype(_least_upper_bound(*{_jax_type(arg) for arg in args}))