# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pytype: skip-file
"""
Implements the NumPy API, using the primitives in :mod:`jax.lax`.

NumPy operations are implemented in Python in terms of the primitive operations
in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
implemented in terms of :mod:`jax.lax` operations, we do not need to define
transformation rules such as gradient or batching rules. Instead,
transformations for NumPy primitives can be derived from the transformation
rules for the underlying :code:`lax` primitives.
"""

import builtins
import collections
import operator
import os
import types
from typing import Sequence, FrozenSet, Optional, Tuple, Union, Iterable, cast
from textwrap import dedent as _dedent
import warnings

import numpy as np
import opt_einsum

import jax
from jax import jit, custom_jvp
from .vectorize import vectorize
from .util import _wraps
from jax import core
from jax import dtypes
from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
from jax.config import flags, config
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
from jax.interpreters.masking import Poly
from jax import lax
from jax._src.lax.lax import _device_put_raw
from jax import ops
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip,
                           canonicalize_axis as _canonicalize_axis)
from jax.tree_util import tree_leaves, tree_flatten

FLAGS = flags.FLAGS
# Policy for NumPy-style implicit rank promotion, i.e. prepending singleton
# dimensions so that operands of different ranks can broadcast together:
# "allow" does it silently, "warn" emits a warning, "raise" raises ValueError.
# The default is read from the JAX_NUMPY_RANK_PROMOTION environment variable
# once, at import time.
flags.DEFINE_enum(
    'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
    enum_values=['allow', 'warn', 'raise'],
    help=
    'Control NumPy-style automatic rank promotion broadcasting '
    '("allow", "warn", or "raise").')

# Mirrors numpy.newaxis, which is also None.
newaxis = None

# Common docstring additions:

# Extra documentation attached (as ``lax_description`` of ``_wraps``) to
# functions that accept a ``precision`` keyword, e.g. convolve/correlate.
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, a ``lax.Precision`` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple
of two ``lax.Precision`` enums indicating separate precision for each argument.
"""

# This module redefines some builtin names (e.g. ``abs``) to follow NumPy's
# API, so the original builtins are captured here under underscored names.
_abs = builtins.abs
_all = builtins.all
_any = builtins.any
_max = builtins.max
_min = builtins.min
_sum = builtins.sum
_divmod = builtins.divmod

# NumPy constants

pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
# These are plain Python floats. They are spelled out instead of being read
# from np.NINF / np.PZERO / np.NZERO, which NumPy 2.0 removed.
NINF = -np.inf
PZERO = 0.0
NZERO = -0.0
nan = np.nan

# And some numpy utility functions
set_printoptions = np.set_printoptions

# We want isinstance(x, np.ndarray) checks in user code to work with our
# array-like types, including DeviceArray and UnshapedArray (i.e. the abstract
# array base class). We can override the isinstance behavior directly, without
# having the complexity of multiple inheritance on those classes, by defining
# the ndarray class to have a metaclass with special __instancecheck__ behavior.
_arraylike_types = (np.ndarray, UnshapedArray, DeviceArray)


class _ArrayMeta(type(np.ndarray)):  # type: ignore
  """Metaclass for overriding ndarray isinstance checks."""

  def __instancecheck__(self, instance):
    # Objects that carry an abstract value in ``.aval`` (presumably tracers)
    # are classified by that value; everything else by the object itself.
    try:
      return isinstance(instance.aval, _arraylike_types)
    except AttributeError:
      return isinstance(instance, _arraylike_types)


class ndarray(np.ndarray, metaclass=_ArrayMeta):
  """Abstract array type used for ``isinstance`` checks; never instantiated."""
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  # ``self`` is listed explicitly so that keyword calls such as
  # ``ndarray(shape=...)`` reach the intended error message rather than a
  # "multiple values for argument 'shape'" TypeError.
  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")


iscomplexobj = np.iscomplexobj

shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size
# _dtype(*xs) returns the result dtype of its argument(s).
_dtype = dtypes.result_type

# At present JAX doesn't have a reason to distinguish between scalars and arrays
# in its object system. Further, we want JAX scalars to have the same type
# promotion behaviors as JAX arrays.
# Rather than introducing a separate JAX scalar object type with its own
# promotion behaviors, the JAX scalar types defined below simply return JAX
# arrays when they are called.

class _ScalarMeta(type):
  """Metaclass of the JAX scalar types (``bool_``, ``float32``, ...).

  A scalar type hashes like its NumPy scalar type and compares equal to it.
  Calling a scalar type converts the argument with ``array(x, dtype=...)``.
  """

  def __hash__(self):
    return hash(self.dtype.type)

  def __eq__(self, other):
    if self is other:
      return True
    return self.dtype.type == other

  def __ne__(self, other):
    return not (self == other)

  def __call__(self, x):
    return array(x, dtype=self.dtype)


def _make_scalar_type(np_scalar_type):
  """Create a ``_ScalarMeta`` type named after, and typed as, ``np_scalar_type``."""
  type_name = np_scalar_type.__name__
  class_body = {"dtype": np.dtype(np_scalar_type)}
  return _ScalarMeta(type_name, (object,), class_body)

bool_ = _make_scalar_type(np.bool_)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
float64 = double = _make_scalar_type(np.float64)
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)

# Default-width aliases follow JAX's own default dtypes.
if dtypes.int_ == np.int32:
  int_ = int32
else:
  int_ = int64

if dtypes.float_ == np.float32:
  float_ = float32
else:
  float_ = float64

if dtypes.complex_ == np.complex64:
  complex_ = complex64
else:
  complex_ = complex128

number = np.number
inexact = np.inexact
complexfloating = np.complexfloating
floating = np.floating
integer = np.integer
signedinteger = np.signedinteger
unsignedinteger = np.unsignedinteger

flexible = np.flexible
character = np.character
object_ = np.object_

iinfo = dtypes.iinfo
finfo = dtypes.finfo

dtype = np.dtype
can_cast = dtypes.can_cast
issubsctype = dtypes.issubsctype
promote_types = dtypes.promote_types
197ComplexWarning = np.ComplexWarning 198 199array_str = np.array_str 200array_repr = np.array_repr 201 202save = np.save 203savez = np.savez 204load = np.load 205 206 207### utility functions 208 209_DEFAULT_TYPEMAP = { 210 np.bool_: bool_, 211 np.int_: int_, 212 np.float_: float_, 213 np.complex_: complex_ 214} 215 216_INT_DTYPES = { 217 16: np.int16, 218 32: np.int32, 219 64: np.int64, 220} 221 222def _np_array(obj, dtype=None, **kwargs): 223 """Return a properly-typed numpy array. 224 225 `_np_array(obj, **kwds)` is equivalent to `np.array(obj, **kwds)`, with the 226 exception that when obj.dtype is not defined and dtype is not specified, it 227 uses Jax's default dtypes. 228 """ 229 arr = np.array(obj, dtype=dtype, **kwargs) 230 obj_dtype = getattr(obj, 'dtype', None) 231 arr_dtype = np.dtype(arr.dtype).type 232 if dtype is None and obj_dtype is None and arr_dtype in _DEFAULT_TYPEMAP: 233 arr = arr.astype(_DEFAULT_TYPEMAP[arr_dtype]) 234 return arr 235 236_np_asarray = partial(_np_array, copy=False) 237 238def _promote_shapes(fun_name, *args): 239 """Prepend implicit leading singleton dimensions for Numpy broadcasting.""" 240 if len(args) < 2: 241 return args 242 else: 243 shapes = [shape(arg) for arg in args] 244 nonscalar_ranks = [len(shp) for shp in shapes if shp] 245 if not nonscalar_ranks or len(set(nonscalar_ranks)) == 1: 246 return args 247 else: 248 if FLAGS.jax_numpy_rank_promotion != "allow": 249 _rank_promotion_warning_or_error(fun_name, shapes) 250 result_rank = len(lax.broadcast_shapes(*shapes)) 251 return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp) 252 for arg, shp in zip(args, shapes)] 253 254def _rank_promotion_warning_or_error(fun_name, shapes): 255 if FLAGS.jax_numpy_rank_promotion == "warn": 256 msg = ("Following NumPy automatic rank promotion for {} on shapes {}. 
" 257 "Set the jax_numpy_rank_promotion config option to 'allow' to " 258 "disable this warning; for more information, see " 259 "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") 260 warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) 261 elif FLAGS.jax_numpy_rank_promotion == "raise": 262 msg = ("Operands could not be broadcast together for {} on shapes {} " 263 "and with the config option jax_numpy_rank_promotion='raise'. " 264 "For more information, see " 265 "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") 266 raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes)))) 267 268def _promote_dtypes(*args): 269 """Convenience function to apply Numpy argument dtype promotion.""" 270 # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing. 271 if len(args) < 2: 272 return args 273 else: 274 to_dtype = result_type(*args) 275 return [lax.convert_element_type(x, to_dtype) for x in args] 276 277def _promote_dtypes_inexact(*args): 278 """Convenience function to apply Numpy argument dtype promotion. 
279 280 Promotes arguments to an inexact type.""" 281 to_dtype = _to_inexact_dtype(result_type(*args)) 282 return [lax.convert_element_type(x, to_dtype) for x in args] 283 284def _to_inexact_dtype(dtype): 285 """Promotes a dtype into an inexact dtype, if it is not already one.""" 286 return dtype if issubdtype(dtype, inexact) else promote_types(dtype, float_) 287 288def _complex_elem_type(dtype): 289 """Returns the float type of the real/imaginary parts of a complex dtype.""" 290 return np.abs(np.zeros((), dtype)).dtype 291 292def _result_dtype(op, *args): 293 """Compute result dtype of applying op to arguments with given dtypes.""" 294 args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args] 295 return _dtype(op(*args)) 296 297 298def _arraylike(x): return isinstance(x, ndarray) or isscalar(x) 299def _check_arraylike(fun_name, *args): 300 """Check if all args fit JAX's definition of arraylike (ndarray or scalar).""" 301 assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}" 302 if _any(not _arraylike(arg) for arg in args): 303 pos, arg = next((i, arg) for i, arg in enumerate(args) 304 if not _arraylike(arg)) 305 msg = "{} requires ndarray or scalar arguments, got {} at position {}." 306 raise TypeError(msg.format(fun_name, type(arg), pos)) 307 308def _check_no_float0s(fun_name, *args): 309 """Check if none of the args have dtype float0.""" 310 if _any(dtypes.dtype(arg) is dtypes.float0 for arg in args): 311 raise TypeError( 312 f"Called {fun_name} with a float0 array. " 313 "float0s do not support any operations by design because they " 314 "are not compatible with non-trivial vector spaces. No implicit dtype " 315 "conversion is done. You can use np.zeros_like(arr, dtype=np.float) " 316 "to cast a float0 array to a regular zeros array. 
\n" 317 "If you didn't expect to get a float0 you might have accidentally " 318 "taken a gradient with respect to an integer argument.") 319 320def _promote_args(fun_name, *args): 321 """Convenience function to apply Numpy argument shape and dtype promotion.""" 322 _check_arraylike(fun_name, *args) 323 _check_no_float0s(fun_name, *args) 324 return _promote_shapes(fun_name, *_promote_dtypes(*args)) 325 326def _promote_args_inexact(fun_name, *args): 327 """Convenience function to apply Numpy argument shape and dtype promotion. 328 329 Promotes non-inexact types to an inexact type.""" 330 _check_arraylike(fun_name, *args) 331 _check_no_float0s(fun_name, *args) 332 return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args)) 333 334def _constant_like(x, const): 335 return np.array(const, dtype=_dtype(x)) 336 337### implementations of numpy functions in terms of lax 338 339@_wraps(np.fmin) 340def fmin(x1, x2): 341 return where((x1 < x2) | isnan(x2), x1, x2) 342 343@_wraps(np.fmax) 344def fmax(x1, x2): 345 return where((x1 > x2) | isnan(x2), x1, x2) 346 347@_wraps(np.issubdtype) 348def issubdtype(arg1, arg2): 349 return dtypes.issubdtype(arg1, arg2) 350 351@_wraps(np.isscalar) 352def isscalar(element): 353 return dtypes.is_python_scalar(element) or np.isscalar(element) 354 355iterable = np.iterable 356 357@_wraps(np.result_type) 358def result_type(*args): 359 return dtypes.result_type(*args) 360 361def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): 362 if promote_to_inexact: 363 fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x)) 364 else: 365 fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x)) 366 if lax_doc: 367 doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() 368 return _wraps(numpy_fn, lax_description=doc)(fn) 369 else: 370 return _wraps(numpy_fn)(fn) 371 372def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): 373 if promote_to_inexact: 374 fn = lambda x1, x2: 
lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2)) 375 else: 376 fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2)) 377 if lax_doc: 378 doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() 379 return _wraps(numpy_fn, lax_description=doc)(fn) 380 else: 381 return _wraps(numpy_fn)(fn) 382 383def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False): 384 def fn(x1, x2): 385 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) 386 return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2) 387 return _wraps(numpy_fn)(fn) 388 if lax_doc: 389 doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() 390 return _wraps(numpy_fn, lax_description=doc)(fn) 391 else: 392 return _wraps(numpy_fn)(fn) 393 394fabs = _one_to_one_unop(np.fabs, lax.abs, True) 395bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) 396invert = _one_to_one_unop(np.invert, lax.bitwise_not) 397negative = _one_to_one_unop(np.negative, lax.neg) 398positive = _one_to_one_unop(np.positive, lambda x: x) 399 400floor = _one_to_one_unop(np.floor, lax.floor, True) 401ceil = _one_to_one_unop(np.ceil, lax.ceil, True) 402exp = _one_to_one_unop(np.exp, lax.exp, True) 403log = _one_to_one_unop(np.log, lax.log, True) 404expm1 = _one_to_one_unop(np.expm1, lax.expm1, True) 405log1p = _one_to_one_unop(np.log1p, lax.log1p, True) 406sin = _one_to_one_unop(np.sin, lax.sin, True) 407cos = _one_to_one_unop(np.cos, lax.cos, True) 408tan = _one_to_one_unop(np.tan, lax.tan, True) 409arcsin = _one_to_one_unop(np.arcsin, lax.asin, True) 410arccos = _one_to_one_unop(np.arccos, lax.acos, True) 411arctan = _one_to_one_unop(np.arctan, lax.atan, True) 412sinh = _one_to_one_unop(np.sinh, lax.sinh, True) 413cosh = _one_to_one_unop(np.cosh, lax.cosh, True) 414arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) 415tanh = _one_to_one_unop(np.tanh, lax.tanh, True) 416arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) 417arctanh = 
_one_to_one_unop(np.arctanh, lax.atanh, True) 418sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) 419 420 421add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) 422bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) 423bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) 424bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) 425left_shift = _one_to_one_binop(np.left_shift, lax.shift_left) 426equal = _one_to_one_binop(np.equal, lax.eq) 427multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) 428not_equal = _one_to_one_binop(np.not_equal, lax.ne) 429subtract = _one_to_one_binop(np.subtract, lax.sub) 430arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True) 431minimum = _one_to_one_binop(np.minimum, lax.min) 432maximum = _one_to_one_binop(np.maximum, lax.max) 433float_power = _one_to_one_binop(np.float_power, lax.pow, True) 434nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True) 435 436@_wraps(np.arccosh) 437def arccosh(x): 438 # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different 439 # convention than np.arccosh. 440 out = lax.acosh(*_promote_args_inexact("arccosh", x)) 441 if issubdtype(out.dtype, np.complexfloating): 442 out = where(real(out) < 0, lax.neg(out), out) 443 return out 444 445def _comparison_op(numpy_fn, lax_fn): 446 def fn(x1, x2): 447 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) 448 # Comparison on complex types are defined as a lexicographic ordering on 449 # the (real, imag) pair. 
450 if issubdtype(_dtype(x1), complexfloating): 451 rx = lax.real(x1) 452 ry = lax.real(x2) 453 return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)), 454 lax_fn(rx, ry)) 455 return lax_fn(x1, x2) 456 return _wraps(numpy_fn)(fn) 457 458greater_equal = _comparison_op(np.greater_equal, lax.ge) 459greater = _comparison_op(np.greater, lax.gt) 460less_equal = _comparison_op(np.less_equal, lax.le) 461less = _comparison_op(np.less, lax.lt) 462 463 464def _logical_op(np_op, bitwise_op): 465 @_wraps(np_op, update_doc=False) 466 def op(*args): 467 zero = lambda x: lax.full_like(x, shape=(), fill_value=0) 468 args = (x if issubdtype(_dtype(x), bool_) else lax.ne(x, zero(x)) 469 for x in args) 470 return bitwise_op(*_promote_args(np_op.__name__, *args)) 471 return op 472 473logical_and = _logical_op(np.logical_and, lax.bitwise_and) 474logical_not = _logical_op(np.logical_not, lax.bitwise_not) 475logical_or = _logical_op(np.logical_or, lax.bitwise_or) 476logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor) 477 478 479@_wraps(np.right_shift) 480def right_shift(x1, x2): 481 x1, x2 = _promote_args(np.right_shift.__name__, x1, x2) 482 lax_fn = lax.shift_right_logical if \ 483 np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic 484 return lax_fn(x1, x2) 485 486 487@_wraps(np.absolute) 488def absolute(x): 489 _check_arraylike('absolute', x) 490 dt = _dtype(x) 491 return x if dt == bool_ or issubdtype(dt, unsignedinteger) else lax.abs(x) 492abs = _wraps(np.abs)(absolute) 493 494 495@_wraps(np.rint) 496def rint(x): 497 _check_arraylike('rint', x) 498 dtype = _dtype(x) 499 if issubdtype(dtype, integer): 500 return lax.convert_element_type(x, float_) 501 if issubdtype(dtype, complexfloating): 502 return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) 503 return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) 504 505 506@_wraps(np.sign) 507def sign(x): 508 _check_arraylike('sign', x) 509 dtype = _dtype(x) 510 if issubdtype(dtype, 
complexfloating): 511 re = lax.real(x) 512 return lax.complex( 513 lax.sign(where(re != 0, re, lax.imag(x))), _constant_like(re, 0)) 514 return lax.sign(x) 515 516 517@_wraps(np.copysign) 518def copysign(x1, x2): 519 x1, x2 = _promote_args_inexact("copysign", x1, x2) 520 if issubdtype(_dtype(x1), complexfloating): 521 raise TypeError("copysign does not support complex-valued inputs") 522 return where(signbit(x2), -lax.abs(x1), lax.abs(x1)) 523 524 525@_wraps(np.true_divide) 526def true_divide(x1, x2): 527 x1, x2 = _promote_args_inexact("true_divide", x1, x2) 528 return lax.div(x1, x2) 529 530divide = true_divide 531 532@_wraps(np.floor_divide) 533def floor_divide(x1, x2): 534 x1, x2 = _promote_args("floor_divide", x1, x2) 535 dtype = _dtype(x1) 536 if issubdtype(dtype, integer): 537 quotient = lax.div(x1, x2) 538 select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0) 539 # TODO(mattjj): investigate why subtracting a scalar was causing promotion 540 return where(select, quotient - np.array(1, _dtype(quotient)), quotient) 541 elif issubdtype(dtype, complexfloating): 542 x1r = lax.real(x1) 543 x1i = lax.imag(x1) 544 x2r = lax.real(x2) 545 x2i = lax.imag(x2) 546 which = lax.ge(lax.abs(x2r), lax.abs(x2i)) 547 rat1 = where(which, lax._const(x2i, 1), lax.div(x2r, x2i)) 548 rat2 = where(which, lax.div(x2i, x2r), lax._const(x2i, 1)) 549 out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)), 550 lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2)))) 551 return lax.convert_element_type(out, dtype) 552 else: 553 return _float_divmod(x1, x2)[0] 554 555 556@_wraps(np.divmod) 557def divmod(x1, x2): 558 x1, x2 = _promote_args("divmod", x1, x2) 559 if issubdtype(_dtype(x1), integer): 560 return floor_divide(x1, x2), remainder(x1, x2) 561 else: 562 return _float_divmod(x1, x2) 563 564 565def _float_divmod(x1, x2): 566 # see float_divmod in floatobject.c of CPython 567 mod = lax.rem(x1, x2) 568 div = lax.div(lax.sub(x1, mod), x2) 569 570 ind = 
lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod)) 571 mod = lax.select(ind, mod + x2, mod) 572 div = lax.select(ind, div - _constant_like(div, 1), div) 573 574 return lax.round(div), mod 575 576 577@_wraps(np.power) 578def power(x1, x2): 579 # Special case for concrete integer scalars: use binary exponentiation. 580 # Using lax.pow may be imprecise for floating-point values; the goal of this 581 # code path is to make sure we end up with a precise output for the common 582 # pattern ``x ** 2`` or similar. 583 try: 584 x2 = core.concrete_or_error(operator.index, x2) 585 except (core.ConcretizationTypeError, TypeError): 586 pass 587 else: 588 return lax.integer_pow(x1, x2) 589 590 x1, x2 = _promote_args("power", x1, x2) 591 dtype = _dtype(x1) 592 if not issubdtype(dtype, integer): 593 return lax.pow(x1, x2) 594 595 # Integer power => use binary exponentiation. 596 597 # TODO(phawkins): add integer pow support to XLA. 598 bits = 6 # Anything more would overflow for any x1 > 1 599 acc = ones(shape(x1), dtype=dtype) 600 for _ in range(bits): 601 acc = where(lax.bitwise_and(x2, _constant_like(x2, 1)), 602 lax.mul(acc, x1), acc) 603 x1 = lax.mul(x1, x1) 604 x2 = lax.shift_right_logical(x2, _constant_like(x2, 1)) 605 return acc 606 607 608@custom_jvp 609@_wraps(np.logaddexp) 610def logaddexp(x1, x2): 611 x1, x2 = _promote_shapes("logaddexp", *_promote_dtypes_inexact(x1, x2)) 612 amax = lax.max(x1, x2) 613 delta = lax.sub(x1, x2) 614 return lax.select(isnan(delta), 615 lax.add(x1, x2), # NaNs or infinities of the same sign. 
616 lax.add(amax, lax.log1p(lax.exp(-lax.abs(delta))))) 617 618@logaddexp.defjvp 619def _logaddexp_jvp(primals, tangents): 620 x1, x2 = primals 621 t1, t2 = tangents 622 x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2) 623 primal_out = logaddexp(x1, x2) 624 tangent_out = (t1 * exp(_replace_inf(x1) - _replace_inf(primal_out)) + 625 t2 * exp(_replace_inf(x2) - _replace_inf(primal_out))) 626 return primal_out, tangent_out 627 628def _replace_inf(x): 629 return lax.select(isposinf(x), zeros_like(x), x) 630 631 632@custom_jvp 633@_wraps(np.logaddexp2) 634def logaddexp2(x1, x2): 635 x1, x2 = _promote_shapes("logaddexp2", *_promote_dtypes_inexact(x1, x2)) 636 amax = lax.max(x1, x2) 637 delta = lax.sub(x1, x2) 638 return lax.select(isnan(delta), 639 lax.add(x1, x2), # NaNs or infinities of the same sign. 640 lax.add(amax, lax.div(lax.log1p(exp2(-lax.abs(delta))), 641 _constant_like(x1, np.log(2))))) 642@logaddexp2.defjvp 643def _logaddexp2_jvp(primals, tangents): 644 x1, x2 = primals 645 t1, t2 = tangents 646 x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2) 647 primal_out = logaddexp2(x1, x2) 648 tangent_out = (t1 * 2 ** (_replace_inf(x1) - _replace_inf(primal_out)) + 649 t2 * 2 ** (_replace_inf(x2) - _replace_inf(primal_out))) 650 return primal_out, tangent_out 651 652 653@_wraps(np.log2) 654def log2(x): 655 x, = _promote_dtypes_inexact(x) 656 return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) 657 658 659@_wraps(np.log10) 660def log10(x): 661 x, = _promote_dtypes_inexact(x) 662 return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) 663 664 665@_wraps(np.exp2) 666def exp2(x): 667 x, = _promote_dtypes_inexact(x) 668 return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x)) 669 670@_wraps(np.signbit) 671def signbit(x): 672 x, = _promote_shapes("signbit", x) 673 dtype = _dtype(x) 674 if issubdtype(dtype, integer): 675 return lax.lt(x, _constant_like(x, 0)) 676 elif issubdtype(dtype, bool_): 677 return full_like(x, False, dtype=bool_) 678 elif not 
issubdtype(dtype, floating): 679 raise ValueError( 680 "jax.numpy.signbit is not well defined for %s" % dtype) 681 682 # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to 683 # F32. 684 if dtype == bfloat16: 685 dtype = float32 686 x = lax.convert_element_type(x, float32) 687 688 info = finfo(dtype) 689 if info.bits not in _INT_DTYPES: 690 raise NotImplementedError( 691 "jax.numpy.signbit only supports 16, 32, and 64-bit types.") 692 int_type = _INT_DTYPES[info.bits] 693 x = lax.bitcast_convert_type(x, int_type) 694 return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_) 695 696 697@_wraps(np.trapz) 698def trapz(y, x=None, dx=1.0, axis: int = -1): 699 _check_arraylike('trapz', y) 700 y = moveaxis(y, axis, -1) 701 if x is not None: 702 if ndim(x) == 1: 703 dx = diff(x) 704 else: 705 dx = moveaxis(diff(x, axis=axis), axis, -1) 706 return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1) 707 708 709@_wraps(np.trunc) 710def trunc(x): 711 _check_arraylike('trunc', x) 712 return where(lax.lt(x, lax._const(x, 0)), ceil(x), floor(x)) 713 714 715def _conv(x, y, mode, op, precision): 716 if issubdtype(_dtype(x), complexfloating) or issubdtype(_dtype(y), complexfloating): 717 raise NotImplementedError(f"{op}() does not support complex inputs") 718 if ndim(x) != 1 or ndim(y) != 1: 719 raise ValueError(f"{op}() only support 1-dimensional inputs.") 720 x, y = _promote_dtypes_inexact(x, y) 721 if len(x) == 0 or len(y) == 0: 722 raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.") 723 724 out_order = slice(None) 725 if len(x) < len(y): 726 x, y = y, x 727 if op == "correlate": 728 out_order = slice(None, None, -1) 729 if op == 'convolve': 730 y = y[::-1] 731 732 if mode == 'valid': 733 padding = [(0, 0)] 734 elif mode == 'same': 735 padding = [(y.shape[0] // 2, y.shape[0] - y.shape[0] // 2 - 1)] 736 elif mode == 'full': 737 padding = [(y.shape[0] - 1, y.shape[0] - 1)] 738 else: 739 raise ValueError("mode must 
be one of ['full', 'same', 'valid']") 740 741 result = lax.conv_general_dilated(x[None, None, :], y[None, None, :], (1,), 742 padding, precision=precision) 743 return result[0, 0, out_order] 744 745 746@_wraps(np.convolve, lax_description=_PRECISION_DOC) 747def convolve(a, v, mode='full', *, precision=None): 748 _check_arraylike("convolve", a, v) 749 return _conv(a, v, mode, 'convolve', precision) 750 751 752@_wraps(np.correlate, lax_description=_PRECISION_DOC) 753def correlate(a, v, mode='valid', *, precision=None): 754 _check_arraylike("correlate", a, v) 755 return _conv(a, v, mode, 'correlate', precision) 756 757 758def _normalize_float(x): 759 info = finfo(_dtype(x)) 760 cond = lax.abs(x) < info.tiny 761 x1 = where(cond, x * lax._const(x, 1 << info.nmant), x) 762 x2 = where(cond, lax._const(np.int32, -info.nmant), lax._const(np.int32, 0)) 763 int_type = _INT_DTYPES[info.bits] 764 return lax.bitcast_convert_type(x1, int_type), x2 765 766 767@_wraps(np.ldexp) 768@jit 769def ldexp(x1, x2): 770 dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2)) 771 x1, x2 = _promote_shapes("ldexp", x1, x2) 772 x1 = lax.convert_element_type(x1, dtype) 773 774 info = finfo(dtype) 775 mask = (1 << info.nexp) - 1 776 bias = ((1 << info.nexp) - 1) >> 1 777 778 int_type = _INT_DTYPES[info.bits] 779 780 x, e = _normalize_float(x1) 781 x2 += e + ((x >> info.nmant) & mask) - bias 782 783 # find underflow/overflow before denormalization 784 underflow_cond = x2 < -(bias + info.nmant) 785 overflow_cond = x2 > bias 786 787 m = ones_like(x, dtype=dtype) 788 789 # denormals 790 cond = x2 < -bias + 1 791 x2 = where(cond, x2 + info.nmant, x2) 792 m = where(cond, m / (1 << info.nmant), m) 793 794 x2 = lax.convert_element_type(x2, np.int32) 795 x &= ~(mask << info.nmant) 796 x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) 797 798 x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) 799 800 # underflow 801 x = where(underflow_cond, 
zeros_like(x, dtype=dtype), x) 802 # overflow 803 x = where(overflow_cond, lax.sign(x1) * full_like(x, np.inf), x) 804 # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 805 return where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) 806 807 808@_wraps(np.frexp) 809@jit 810def frexp(x): 811 x = asarray(x) 812 if issubdtype(x.dtype, complexfloating): 813 raise TypeError("frexp does not support complex-valued inputs") 814 elif not issubdtype(x.dtype, floating): 815 x = lax.convert_element_type(x, float_) 816 817 dtype = _dtype(x) 818 info = finfo(dtype) 819 mask = (1 << info.nexp) - 1 820 bias = ((1 << info.nexp) - 1) >> 1 821 822 x1, x2 = _normalize_float(x) 823 x2 += ((x1 >> info.nmant) & mask) - bias + 1 824 x1 &= ~(mask << info.nmant) 825 x1 |= (bias - 1) << info.nmant 826 x1 = lax.bitcast_convert_type(x1, dtype) 827 828 cond = isinf(x) | isnan(x) | (x == 0) 829 x2 = where(cond, zeros_like(x2), x2) 830 return where(cond, x, x1), lax.convert_element_type(x2, int32) 831 832 833@_wraps(np.remainder) 834def remainder(x1, x2): 835 x1, x2 = _promote_args("remainder", x1, x2) 836 zero = _constant_like(x1, 0) 837 trunc_mod = lax.rem(x1, x2) 838 trunc_mod_not_zero = lax.ne(trunc_mod, zero) 839 do_plus = lax.bitwise_and( 840 lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) 841 return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) 842mod = _wraps(np.mod)(remainder) 843 844 845@_wraps(np.fmod) 846def fmod(x1, x2): 847 _check_arraylike("fmod", x1, x2) 848 if issubdtype(_dtype(x1, x2), integer): 849 x2 = where(x2 == 0, 1, x2) 850 return lax.rem(*_promote_args("fmod", x1, x2)) 851 852 853@_wraps(np.cbrt) 854def cbrt(x): 855 _check_arraylike("cbrt", x) 856 x, = _promote_dtypes_inexact(x) 857 return lax.sign(x) * power(lax.abs(x), _constant_like(x, 1. 
/ 3.)) 858 859 860@_wraps(np.square) 861def square(x): 862 _check_arraylike("square", x) 863 return lax.integer_pow(x, 2) 864 865 866@_wraps(np.deg2rad) 867def deg2rad(x): 868 _check_arraylike("deg2rad", x) 869 x, = _promote_dtypes_inexact(x) 870 return lax.mul(x, lax._const(x, pi / 180)) 871 872 873@_wraps(np.rad2deg) 874def rad2deg(x): 875 _check_arraylike("rad2deg", x) 876 x, = _promote_dtypes_inexact(x) 877 return lax.mul(x, lax._const(x, 180 / pi)) 878 879 880degrees = rad2deg 881radians = deg2rad 882 883 884@_wraps(np.histogram_bin_edges) 885def histogram_bin_edges(a, bins=10, range=None, weights=None): 886 if isinstance(bins, str): 887 raise NotImplementedError("string values for `bins` not implemented.") 888 a = ravel(a) 889 b = asarray(bins) 890 if b.ndim == 1: 891 return b 892 if range is None: 893 range = (a.min(), a.max()) 894 assert len(range) == 2 895 range = asarray(range) 896 range = (where(ptp(range) == 0, range[0] - 0.5, range[0]), 897 where(ptp(range) == 0, range[1] + 0.5, range[1])) 898 dtype = _dtype(a) 899 if issubdtype(dtype, integer): 900 dtype = promote_types(dtype, float32) 901 return linspace(range[0], range[1], bins + 1, dtype=dtype) 902 903 904@_wraps(np.histogram) 905def histogram(a, bins=10, range=None, weights=None, density=None): 906 if weights is not None and a.shape != weights.shape: 907 raise ValueError("weights should have the same shape as a.") 908 a = ravel(a) 909 if weights is not None: 910 weights = ravel(weights) 911 else: 912 weights = ones_like(a) 913 bin_edges = histogram_bin_edges(a, bins, range, weights) 914 bin_idx = searchsorted(bin_edges, a, side='right') 915 bin_idx = where(a == bin_edges[-1], len(bin_edges) - 1, bin_idx) 916 counts = bincount(bin_idx, weights, length=len(bin_edges))[1:] 917 if density: 918 bin_widths = diff(bin_edges) 919 counts = counts / bin_widths / counts.sum() 920 return counts, bin_edges 921 922@_wraps(np.histogram2d) 923def histogram2d(x, y, bins=10, range=None, weights=None, density=None): 
924 925 try: 926 N = len(bins) 927 except TypeError: 928 N = 1 929 930 if N != 1 and N != 2: 931 x_edges = y_edges = asarray(bins) 932 bins = [x_edges, y_edges] 933 934 sample = transpose(asarray([x, y])) 935 hist, edges = histogramdd(sample, bins, range, weights, density) 936 return hist, edges[0], edges[1] 937 938@_wraps(np.histogramdd) 939def histogramdd(sample, bins=10, range=None, weights=None, density=None): 940 _check_arraylike("histogramdd", sample) 941 N, D = shape(sample) 942 943 if weights is not None and weights.shape != (N,): 944 raise ValueError("should have one weight for each sample.") 945 946 try: 947 num_bins = len(bins) 948 if num_bins != D: 949 raise ValueError("should be a bin for each dimension.") 950 except TypeError: 951 # when bin_size is integer, the same bin is used for each dimension 952 bins = D * [bins] 953 954 bin_idx_by_dim = D*[None] 955 nbins = np.empty(D, int) 956 bin_edges_by_dim = D*[None] 957 dedges = D*[None] 958 959 for i in builtins.range(D): 960 bin_edges = histogram_bin_edges(sample[:, i], bins[i], range, weights) 961 bin_idx = searchsorted(bin_edges, sample[:, i], side='right') 962 bin_idx = where(sample[:, i] == bin_edges[-1], bin_idx - 1, bin_idx) 963 bin_idx_by_dim[i] = bin_idx 964 nbins[i] = len(bin_edges) + 1 965 bin_edges_by_dim[i] = bin_edges 966 dedges[i] = diff(bin_edges_by_dim[i]) 967 968 xy = ravel_multi_index(bin_idx_by_dim, nbins, mode='clip') 969 hist = bincount(xy, weights, length=nbins.prod()) 970 hist = reshape(hist, nbins) 971 core = D*(slice(1, -1),) 972 hist = hist[core] 973 974 if density: 975 s = sum(hist) 976 for i in builtins.range(D): 977 _shape = np.ones(D, int) 978 _shape[i] = nbins[i] - 2 979 hist = hist / reshape(dedges[i], _shape) 980 981 hist /= s 982 983 return hist, bin_edges_by_dim 984 985@_wraps(np.heaviside) 986def heaviside(x1, x2): 987 _check_arraylike("heaviside", x1, x2) 988 x1, x2 = _promote_dtypes_inexact(x1, x2) 989 zero = lax._const(x1, 0) 990 return where(lax.lt(x1, zero), 
zero, 991 where(lax.gt(x1, zero), lax._const(x1, 1), x2)) 992 993 994@_wraps(np.hypot) 995def hypot(x1, x2): 996 _check_arraylike("hypot", x1, x2) 997 x1, x2 = _promote_dtypes_inexact(x1, x2) 998 x1 = lax.abs(x1) 999 x2 = lax.abs(x2) 1000 x1, x2 = maximum(x1, x2), minimum(x1, x2) 1001 return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, ones_like(x1), x1))))) 1002 1003 1004@_wraps(np.reciprocal) 1005def reciprocal(x): 1006 _check_arraylike("reciprocal", x) 1007 x, = _promote_dtypes_inexact(x) 1008 return lax.integer_pow(x, -1) 1009 1010 1011@_wraps(np.sinc, update_doc=False) 1012def sinc(x): 1013 _check_arraylike("sinc", x) 1014 x, = _promote_dtypes_inexact(x) 1015 eq_zero = lax.eq(x, lax._const(x, 0)) 1016 pi_x = lax.mul(lax._const(x, pi), x) 1017 safe_pi_x = where(eq_zero, lax._const(x, 0), pi_x) 1018 return where(eq_zero, _sinc_maclaurin(0, pi_x), 1019 lax.div(lax.sin(safe_pi_x), safe_pi_x)) 1020 1021@partial(custom_jvp, nondiff_argnums=(0,)) 1022def _sinc_maclaurin(k, x): 1023 # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we 1024 # compute the monomial term in the jvp rule) 1025 if k % 2: 1026 return lax.full_like(x, 0) 1027 else: 1028 return lax.full_like(x, (-1) ** (k // 2) / (k + 1)) 1029 1030@_sinc_maclaurin.defjvp 1031def _sinc_maclaurin_jvp(k, primals, tangents): 1032 (x,), (t,) = primals, tangents 1033 return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t 1034 1035 1036@_wraps(np.transpose) 1037def transpose(a, axes=None): 1038 _check_arraylike("transpose", a) 1039 axes = np.arange(ndim(a))[::-1] if axes is None else axes 1040 return lax.transpose(a, axes) 1041 1042 1043@_wraps(np.rot90) 1044def rot90(m, k=1, axes=(0, 1)): 1045 _check_arraylike("rot90", m) 1046 ax1, ax2 = axes 1047 ax1 = _canonicalize_axis(ax1, ndim(m)) 1048 ax2 = _canonicalize_axis(ax2, ndim(m)) 1049 if ax1 == ax2: 1050 raise ValueError("Axes must be different") # same as numpy error 1051 k = k % 4 1052 if k == 0: 
1053 return m 1054 elif k == 2: 1055 return flip(flip(m, ax1), ax2) 1056 else: 1057 perm = list(range(m.ndim)) 1058 perm[ax1], perm[ax2] = perm[ax2], perm[ax1] 1059 if k == 1: 1060 return transpose(flip(m, ax2), perm) 1061 else: 1062 return flip(transpose(m, perm), ax2) 1063 1064 1065@_wraps(np.flip) 1066def flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None): 1067 _check_arraylike("flip", m) 1068 if axis is None: 1069 return lax.rev(m, list(range(len(shape(m))))) 1070 return lax.rev(m, [_canonicalize_axis(axis, ndim(m))]) 1071 1072 1073@_wraps(np.fliplr) 1074def fliplr(m): 1075 return flip(m, 1) 1076 1077 1078@_wraps(np.flipud) 1079def flipud(m): 1080 return flip(m, 0) 1081 1082 1083@_wraps(np.conjugate) 1084def conjugate(x): 1085 _check_arraylike("conjugate", x) 1086 return lax.conj(x) if iscomplexobj(x) else x 1087conj = conjugate 1088 1089 1090@_wraps(np.imag) 1091def imag(val): 1092 _check_arraylike("imag", val) 1093 return lax.imag(val) if iscomplexobj(val) else zeros_like(val) 1094 1095 1096@_wraps(np.real) 1097def real(val): 1098 _check_arraylike("real", val) 1099 return lax.real(val) if iscomplexobj(val) else val 1100 1101 1102@_wraps(np.iscomplex) 1103def iscomplex(x): 1104 i = imag(x) 1105 return lax.ne(i, lax._const(i, 0)) 1106 1107@_wraps(np.isreal) 1108def isreal(x): 1109 i = imag(x) 1110 return lax.eq(i, lax._const(i, 0)) 1111 1112@_wraps(np.angle) 1113def angle(z): 1114 re = real(z) 1115 im = imag(z) 1116 dtype = _dtype(re) 1117 if not issubdtype(dtype, inexact) or ( 1118 issubdtype(_dtype(z), floating) and ndim(z) == 0): 1119 dtype = dtypes.canonicalize_dtype(float_) 1120 re = lax.convert_element_type(re, dtype) 1121 im = lax.convert_element_type(im, dtype) 1122 return lax.atan2(im, re) 1123 1124 1125@_wraps(np.diff) 1126def diff(a, n=1, axis: int = -1, prepend=None, append=None): 1127 _check_arraylike("diff", a) 1128 if n == 0: 1129 return a 1130 if n < 0: 1131 raise ValueError(f"order must be non-negative but got {n}") 1132 if ndim(a) == 
0: 1133 raise ValueError(f"diff requires input that is at least one dimensional; got {a}") 1134 1135 nd = a.ndim 1136 1137 combined = [] 1138 if prepend is not None: 1139 _check_arraylike("diff", prepend) 1140 if isscalar(prepend): 1141 shape = list(a.shape) 1142 shape[axis] = 1 1143 prepend = broadcast_to(prepend, tuple(shape)) 1144 combined.append(prepend) 1145 1146 combined.append(a) 1147 1148 if append is not None: 1149 _check_arraylike("diff", append) 1150 if isscalar(append): 1151 shape = list(a.shape) 1152 shape[axis] = 1 1153 append = broadcast_to(append, tuple(shape)) 1154 combined.append(append) 1155 1156 if len(combined) > 1: 1157 a = concatenate(combined, axis) 1158 1159 slice1 = [slice(None)] * nd 1160 slice2 = [slice(None)] * nd 1161 slice1[axis] = slice(1, None) 1162 slice2[axis] = slice(None, -1) 1163 slice1_tuple = tuple(slice1) 1164 slice2_tuple = tuple(slice2) 1165 1166 op = not_equal if a.dtype == np.bool_ else subtract 1167 for _ in range(n): 1168 a = op(a[slice1_tuple], a[slice2_tuple]) 1169 1170 return a 1171 1172_EDIFF1D_DOC = """\ 1173Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not 1174issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary`` 1175loses precision. 
1176""" 1177 1178@_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC) 1179def ediff1d(ary, to_end=None, to_begin=None): 1180 ary = ravel(asarray(ary)) 1181 result = lax.sub(ary[1:], ary[:-1]) 1182 if to_begin is not None: 1183 result = concatenate((ravel(asarray(to_begin, dtype=ary.dtype)), result)) 1184 if to_end is not None: 1185 result = concatenate((result, ravel(asarray(to_end, dtype=ary.dtype)))) 1186 return result 1187 1188 1189@partial(jit, static_argnums=2) 1190def _gradient(a, varargs, axis): 1191 def gradient_along_axis(a, h, axis): 1192 sliced = partial(lax.slice_in_dim, a, axis=axis) 1193 a_grad = concatenate(( 1194 (sliced(1, 2) - sliced(0, 1)), # upper edge 1195 (sliced(2, None) - sliced(None, -2)) * 0.5, # inner 1196 (sliced(-1, None) - sliced(-2, -1)), # lower edge 1197 ), axis) 1198 return a_grad / h 1199 1200 if axis is None: 1201 axis = range(a.ndim) 1202 else: 1203 if isinstance(axis, int): 1204 axis = (axis,) 1205 if not isinstance(axis, tuple) and not isinstance(axis, list): 1206 raise ValueError("Give `axis` either as int or iterable") 1207 elif len(axis) == 0: 1208 return [] 1209 axis = [_canonicalize_axis(i, a.ndim) for i in axis] 1210 1211 if _min([s for i, s in enumerate(a.shape) if i in axis]) < 2: 1212 raise ValueError("Shape of array too small to calculate " 1213 "a numerical gradient, " 1214 "at least 2 elements are required.") 1215 len_axes = len(axis) 1216 n = len(varargs) 1217 if n == 0 or varargs is None: 1218 # no spacing 1219 dx = [1.0] * len_axes 1220 elif n == 1: 1221 # single value for all axes 1222 dx = varargs * len_axes 1223 elif n == len_axes: 1224 dx = varargs 1225 else: 1226 TypeError("Invalid number of spacing arguments %d" % n) 1227 1228 if ndim(dx[0]) != 0: 1229 raise NotImplementedError("Non-constant spacing not implemented") 1230 1231 # TODO: use jax.lax loop tools if possible 1232 a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis, dx)] 1233 1234 if len(axis) == 1: 1235 a_grad = a_grad[0] 1236 1237 return 
a_grad 1238 1239 1240@_wraps(np.gradient) 1241def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None, 1242 edge_order=None): 1243 if edge_order is not None: 1244 raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") 1245 return _gradient(f, varargs, axis) 1246 1247 1248@_wraps(np.isrealobj) 1249def isrealobj(x): 1250 return not iscomplexobj(x) 1251 1252 1253@_wraps(np.reshape) 1254def reshape(a, newshape, order="C"): 1255 try: 1256 return a.reshape(newshape, order=order) # forward to method for ndarrays 1257 except AttributeError: 1258 return _reshape(a, newshape, order=order) 1259 1260def _compute_newshape(a, newshape): 1261 """Fixes a -1 value in newshape, if present.""" 1262 # other errors, like having more than one -1, are caught downstream 1263 try: iter(newshape) 1264 except: iterable = False 1265 else: iterable = True 1266 def check(size): 1267 return size if type(size) is Poly else core.concrete_or_error( 1268 operator.index, size, "The error arose in jax.numpy.reshape.") 1269 newshape = [check(size) for size in newshape] if iterable else [check(newshape)] 1270 if np.any(np.equal(newshape, -1)): 1271 fix = -a.size // (newshape if type(newshape) is Poly else _prod(newshape)) 1272 return [d if d != -1 else fix for d in newshape] 1273 else: 1274 return newshape 1275 1276def _reshape(a, *args, order="C"): 1277 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args) 1278 if order == "C": 1279 return lax.reshape(a, newshape, None) 1280 elif order == "F": 1281 dims = np.arange(ndim(a))[::-1] 1282 return lax.reshape(a, newshape[::-1], dims).T 1283 elif order == "A": 1284 raise NotImplementedError("np.reshape order=A is not implemented.") 1285 else: 1286 raise ValueError("Unexpected value for 'order' argument: {}.".format(order)) 1287 1288@_wraps(np.ravel) 1289def ravel(a, order="C"): 1290 if order == "K": 1291 raise NotImplementedError("Ravel not implemented for order='K'.") 1292 return 
reshape(a, (size(a),), order) 1293 1294 1295@_wraps(np.ravel_multi_index) 1296def ravel_multi_index(multi_index, dims, mode='raise', order='C'): 1297 assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}" 1298 dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims) 1299 _check_arraylike("ravel_multi_index", *multi_index) 1300 for index in multi_index: 1301 if mode == 'raise': 1302 core.concrete_or_error(array, index, 1303 "The error occurred because ravel_multi_index was jit-compiled" 1304 " with mode='raise'. Use mode='wrap' or mode='clip' instead.") 1305 if not issubdtype(_dtype(index), integer): 1306 raise TypeError("only int indices permitted") 1307 if mode == "raise": 1308 if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index, dims)): 1309 raise ValueError("invalid entry in coordinates array") 1310 elif mode == "clip": 1311 multi_index = [clip(i, 0, d - 1) for i, d in zip(multi_index, dims)] 1312 elif mode == "wrap": 1313 multi_index = [i % d for i, d in zip(multi_index, dims)] 1314 else: 1315 raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'") 1316 1317 if order == "F": 1318 strides = np.cumprod((1,) + dims[:-1]) 1319 elif order == "C": 1320 strides = np.cumprod((1,) + dims[1:][::-1])[::-1] 1321 else: 1322 raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'") 1323 1324 result = 0 1325 for i, s in zip(multi_index, strides): 1326 result = result + i * s 1327 return result 1328 1329 1330_UNRAVEL_INDEX_DOC = """\ 1331Unlike numpy's implementation of unravel_index, negative indices are accepted 1332and out-of-bounds indices are clipped. 
1333""" 1334 1335@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC) 1336def unravel_index(indices, shape): 1337 indices = asarray(indices) 1338 sizes = pad(shape, (0, 1), constant_values=1) 1339 cumulative_sizes = cumprod(sizes[::-1])[::-1] 1340 total_size = cumulative_sizes[0] 1341 # Clip so raveling and unraveling an oob index will not change the behavior 1342 clipped_indices = clip(indices, -total_size, total_size - 1) 1343 # Add enough trailing dims to avoid conflict with flat_index 1344 cumulative_sizes = cumulative_sizes.reshape([-1] + [1] * indices.ndim) 1345 idx = clipped_indices % cumulative_sizes[:-1] // cumulative_sizes[1:] 1346 return tuple(idx) 1347 1348 1349@_wraps(np.squeeze) 1350def squeeze(a, axis: Optional[Union[int, Tuple[int, ...]]] = None): 1351 _check_arraylike("squeeze", a) 1352 if axis is None: 1353 a_shape = shape(a) 1354 axis = tuple(i for i, d in enumerate(a_shape) if d == 1) 1355 elif not isinstance(axis, tuple): 1356 axis = (axis,) 1357 return lax.squeeze(a, axis) 1358 1359 1360@_wraps(np.expand_dims) 1361def expand_dims(a, axis: Union[int, Tuple[int, ...]]): 1362 _check_arraylike("expand_dims", a) 1363 if not isinstance(axis, tuple): 1364 axis = (axis,) 1365 return lax.expand_dims(a, axis) 1366 1367 1368@_wraps(np.swapaxes) 1369def swapaxes(a, axis1: int, axis2: int): 1370 _check_arraylike("swapaxes", a) 1371 perm = np.arange(ndim(a)) 1372 perm[axis1], perm[axis2] = perm[axis2], perm[axis1] 1373 return lax.transpose(a, perm) 1374 1375 1376@_wraps(np.moveaxis) 1377def moveaxis(a, source: Union[int, Sequence[int]], 1378 destination: Union[int, Sequence[int]]): 1379 _check_arraylike("moveaxis", a) 1380 source_axes: Tuple[int, ...] 1381 destination_axes: Tuple[int, ...] 
1382 try: 1383 source_axes = (operator.index(source),) 1384 except TypeError: 1385 source_axes = tuple(cast(Sequence[int], source)) 1386 try: 1387 destination_axes = (operator.index(destination),) 1388 except TypeError: 1389 destination_axes = tuple(cast(Sequence[int], destination)) 1390 source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes) 1391 destination_axes = tuple(_canonicalize_axis(i, ndim(a)) 1392 for i in destination_axes) 1393 if len(source_axes) != len(destination_axes): 1394 raise ValueError("Inconsistent number of elements: {} vs {}" 1395 .format(len(source_axes), len(destination_axes))) 1396 perm = [i for i in range(ndim(a)) if i not in source_axes] 1397 for dest, src in sorted(zip(destination_axes, source_axes)): 1398 perm.insert(dest, src) 1399 return lax.transpose(a, perm) 1400 1401 1402@_wraps(np.isclose) 1403def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): 1404 a, b = _promote_args("isclose", asarray(a), asarray(b)) 1405 dtype = _dtype(a) 1406 if issubdtype(dtype, inexact): 1407 if issubdtype(dtype, complexfloating): 1408 dtype = _complex_elem_type(dtype) 1409 rtol = lax.convert_element_type(rtol, dtype) 1410 atol = lax.convert_element_type(atol, dtype) 1411 out = lax.le( 1412 lax.abs(lax.sub(a, b)), 1413 lax.add(atol, lax.mul(rtol, lax.abs(b)))) 1414 # This corrects the comparisons for infinite and nan values 1415 a_inf = isinf(a) 1416 b_inf = isinf(b) 1417 any_inf = logical_or(a_inf, b_inf) 1418 both_inf = logical_and(a_inf, b_inf) 1419 # Make all elements where either a or b are infinite to False 1420 out = logical_and(out, logical_not(any_inf)) 1421 # Make all elements where both a or b are the same inf to True 1422 same_value = lax.eq(a, b) 1423 same_inf = logical_and(both_inf, same_value) 1424 out = logical_or(out, same_inf) 1425 1426 # Make all elements where either a or b is NaN to False 1427 a_nan = isnan(a) 1428 b_nan = isnan(b) 1429 any_nan = logical_or(a_nan, b_nan) 1430 out = logical_and(out, 
logical_not(any_nan)) 1431 if equal_nan: 1432 # Make all elements where both a and b is NaN to True 1433 both_nan = logical_and(a_nan, b_nan) 1434 out = logical_or(out, both_nan) 1435 return _maybe_numpy_1_13_isclose_behavior(a, out) 1436 else: 1437 return lax.eq(a, b) 1438 1439numpy_version = tuple(map(int, np.version.version.split('.')[:2])) 1440if numpy_version < (1, 14): 1441 # see discussion at https://github.com/numpy/numpy/pull/9720 1442 def _maybe_numpy_1_13_isclose_behavior(a, out): 1443 if size(out) == 1 and issubdtype(_dtype(a), complexfloating): 1444 return lax.reshape(out, (1,)) 1445 else: 1446 return out 1447else: 1448 def _maybe_numpy_1_13_isclose_behavior(a, out): 1449 return out 1450 1451@_wraps(np.interp) 1452def interp(x, xp, fp, left=None, right=None, period=None): 1453 if shape(xp) != shape(fp) or ndim(xp) != 1: 1454 raise ValueError("xp and fp must be one-dimensional arrays of equal size") 1455 x, xp, fp = map(asarray, _promote_dtypes_inexact(x, xp, fp)) 1456 if period is not None: 1457 if period == 0: 1458 raise ValueError(f"period must be a non-zero value; got {period}") 1459 period = abs(period) 1460 x = x % period 1461 xp = xp % period 1462 xp, fp = lax.sort_key_val(xp, fp) 1463 xp = concatenate([xp[-1:] - period, xp, xp[:1] + period]) 1464 fp = concatenate([fp[-1:], fp, fp[:1]]) 1465 1466 i = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1) 1467 df = fp[i] - fp[i - 1] 1468 dx = xp[i] - xp[i - 1] 1469 delta = x - xp[i - 1] 1470 f = where((dx == 0), fp[i], fp[i - 1] + (delta / dx) * df) 1471 1472 if period is None: 1473 f = where(x < xp[0], fp[0] if left is None else left, f) 1474 f = where(x > xp[-1], fp[-1] if right is None else right, f) 1475 return f 1476 1477 1478@_wraps(np.in1d, lax_description=""" 1479In the JAX version, the `assume_unique` argument is not referenced. 
1480""") 1481def in1d(ar1, ar2, assume_unique=False, invert=False): 1482 ar1 = ravel(ar1) 1483 ar2 = ravel(ar2) 1484 # Note: an algorithm based on searchsorted has better scaling, but in practice 1485 # is very slow on accelerators because it relies on lax control flow. If XLA 1486 # ever supports binary search natively, we should switch to this: 1487 # ar2 = jnp.sort(ar2) 1488 # ind = jnp.searchsorted(ar2, ar1) 1489 # if invert: 1490 # return ar1 != ar2[ind] 1491 # else: 1492 # return ar1 == ar2[ind] 1493 if invert: 1494 return (ar1[:, None] != ar2).all(-1) 1495 else: 1496 return (ar1[:, None] == ar2).any(-1) 1497 1498@_wraps(np.setdiff1d, lax_description=""" 1499In the JAX version, the `assume_unique` argument is not referenced. 1500""") 1501def setdiff1d(ar1, ar2, assume_unique=False): 1502 ar1 = core.concrete_or_error(asarray, ar1, "The error arose in setdiff1d()") 1503 ar2 = core.concrete_or_error(asarray, ar2, "The error arose in setdiff1d()") 1504 1505 ar1 = unique(ar1) 1506 ar2 = unique(ar2) 1507 1508 idx = in1d(ar1, ar2, invert=True) 1509 return ar1[idx] 1510 1511@partial(jit, static_argnums=2) 1512def _intersect1d_sorted_mask(ar1, ar2, return_indices=False): 1513 """ 1514 Helper function for intersect1d which is jit-able 1515 """ 1516 ar = concatenate((ar1, ar2)) 1517 if return_indices: 1518 iota = lax.broadcasted_iota(np.int64, shape(ar), dimension=0) 1519 aux, indices = lax.sort_key_val(ar, iota) 1520 else: 1521 aux = sort(ar) 1522 1523 mask = aux[1:] == aux[:-1] 1524 if return_indices: 1525 return aux, mask, indices 1526 else: 1527 return aux, mask 1528 1529 1530@_wraps(np.intersect1d) 1531def intersect1d(ar1, ar2, assume_unique=False, return_indices=False): 1532 ar1 = core.concrete_or_error(asarray, ar1, "The error arose in intersect1d()") 1533 ar2 = core.concrete_or_error(asarray, ar2, "The error arose in intersect1d()") 1534 1535 if not assume_unique: 1536 if return_indices: 1537 ar1, ind1 = unique(ar1, return_index=True) 1538 ar2, ind2 = 
unique(ar2, return_index=True) 1539 else: 1540 ar1 = unique(ar1) 1541 ar2 = unique(ar2) 1542 else: 1543 ar1 = ravel(ar1) 1544 ar2 = ravel(ar2) 1545 1546 if return_indices: 1547 aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices) 1548 else: 1549 aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices) 1550 1551 int1d = aux[:-1][mask] 1552 1553 if return_indices: 1554 ar1_indices = aux_sort_indices[:-1][mask] 1555 ar2_indices = aux_sort_indices[1:][mask] - ar1.size 1556 if not assume_unique: 1557 ar1_indices = ind1[ar1_indices] 1558 ar2_indices = ind2[ar2_indices] 1559 1560 return int1d, ar1_indices, ar2_indices 1561 else: 1562 return int1d 1563 1564 1565@_wraps(np.isin, lax_description=""" 1566In the JAX version, the `assume_unique` argument is not referenced. 1567""") 1568def isin(element, test_elements, assume_unique=False, invert=False): 1569 result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert) 1570 return result.reshape(shape(element)) 1571 1572 1573# The `jit` on `where` exists to avoid materializing constants in cases like 1574# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to 1575# materialize the broadcast forms of scalar arguments. 1576@jit 1577def _where(condition, x=None, y=None): 1578 if x is None or y is None: 1579 raise ValueError("Either both or neither of the x and y arguments should " 1580 "be provided to jax.numpy.where, got {} and {}." 1581 .format(x, y)) 1582 if not issubdtype(_dtype(condition), bool_): 1583 condition = lax.ne(condition, zeros_like(condition)) 1584 x, y = _promote_dtypes(x, y) 1585 condition, x, y = broadcast_arrays(condition, x, y) 1586 return lax.select(condition, x, y) if np.size(x) else x 1587 1588 1589_WHERE_DOC = """\ 1590At present, JAX does not support JIT-compilation of the single-argument form 1591of :py:func:`jax.numpy.where` because its output shape is data-dependent. 
The 1592three-argument form does not have a data-dependent shape and can be JIT-compiled 1593successfully. 1594""" 1595 1596@_wraps(np.where, update_doc=False, lax_description=_WHERE_DOC) 1597def where(condition, x=None, y=None): 1598 if x is None and y is None: 1599 return nonzero(asarray(condition)) 1600 else: 1601 return _where(condition, x, y) 1602 1603 1604@_wraps(np.select) 1605def select(condlist, choicelist, default=0): 1606 if len(condlist) != len(choicelist): 1607 msg = "condlist must have length equal to choicelist ({} vs {})" 1608 raise ValueError(msg.format(len(condlist), len(choicelist))) 1609 if len(condlist) == 0: 1610 raise ValueError("condlist must be non-empty") 1611 choices = _promote_dtypes(default, *choicelist) 1612 choicelist = choices[1:] 1613 output = choices[0] 1614 for cond, choice in zip(condlist[::-1], choicelist[::-1]): 1615 output = where(cond, choice, output) 1616 return output 1617 1618 1619@_wraps(np.bincount, lax_description="""\ 1620Jax adds the optional `length` parameter which specifies the output length, and 1621defaults to ``x.max() + 1``. It must be specified for bincount to be compilable. 1622Values larger than the specified length will be discarded. 1623 1624Additionally, while ``np.bincount`` raises an error if the input array contains 1625negative values, ``jax.numpy.bincount`` treats negative values as zero. 1626""") 1627def bincount(x, weights=None, minlength=0, *, length=None): 1628 _check_arraylike("bincount", x) 1629 if not issubdtype(_dtype(x), integer): 1630 msg = f"x argument to bincount must have an integer type; got {x.dtype}" 1631 raise TypeError(msg) 1632 if length is None: 1633 x = core.concrete_or_error(asarray, x, 1634 "The error occured because of argument 'x' of jnp.bincount. 
" 1635 "To avoid this error, pass a static `length` argument.") 1636 length = max(x) + 1 1637 length = _max(length, minlength) 1638 if ndim(x) != 1: 1639 raise ValueError("only 1-dimensional input supported.") 1640 if weights is None: 1641 weights = array(1, dtype=int32) 1642 else: 1643 if shape(x) != shape(weights): 1644 raise ValueError("shape of weights must match shape of x.") 1645 return ops.index_add(zeros((length,), _dtype(weights)), ops.index[clip(x, 0)], weights) 1646 1647 1648def broadcast_arrays(*args): 1649 """Like Numpy's broadcast_arrays but doesn't return views.""" 1650 shapes = [shape(arg) for arg in args] 1651 if len(set(shapes)) == 1: 1652 return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg) 1653 for arg in args] 1654 result_shape = lax.broadcast_shapes(*shapes) 1655 return [broadcast_to(arg, result_shape) for arg in args] 1656 1657 1658@_wraps(np.broadcast_to, lax_description="""\ 1659The JAX version does not necessarily return a view of the input. 
1660""") 1661def broadcast_to(arr, shape): 1662 arr = arr if isinstance(arr, ndarray) else array(arr) 1663 shape = canonicalize_shape(shape) # check that shape is concrete 1664 arr_shape = _shape(arr) 1665 if arr_shape == shape: 1666 return arr 1667 else: 1668 nlead = len(shape) - len(arr_shape) 1669 compatible = np.equal(arr_shape, shape[nlead:]) | np.equal(arr_shape, 1) 1670 if nlead < 0 or not np.all(compatible): 1671 msg = "Incompatible shapes for broadcasting: {} and requested shape {}" 1672 raise ValueError(msg.format(arr_shape, shape)) 1673 diff, = np.where(np.not_equal(shape[nlead:], arr_shape)) 1674 new_dims = tuple(range(nlead)) + tuple(nlead + diff) 1675 kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims)) 1676 return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims) 1677 1678 1679def _split(op, ary, indices_or_sections, axis=0): 1680 axis = core.concrete_or_error(int, axis, f"in jax.numpy.{op} argument `axis`") 1681 size = ary.shape[axis] 1682 if isinstance(indices_or_sections, (tuple, list) + _arraylike_types): 1683 indices_or_sections = np.array( 1684 [core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1") 1685 for i_s in indices_or_sections], np.int64) 1686 split_indices = np.concatenate([[np.int64(0)], indices_or_sections, 1687 [np.int64(size)]]) 1688 else: 1689 indices_or_sections = core.concrete_or_error(np.int64, indices_or_sections, 1690 f"in jax.numpy.{op} argument 1") 1691 part_size, r = _divmod(size, indices_or_sections) 1692 if r == 0: 1693 split_indices = np.arange(indices_or_sections + 1, 1694 dtype=np.int64) * part_size 1695 elif op == "array_split": 1696 split_indices = np.concatenate( 1697 [np.arange(r + 1, dtype=np.int64) * (part_size + 1), 1698 np.arange(indices_or_sections - r, dtype=np.int64) * part_size 1699 + ((r + 1) * (part_size + 1) - 1)]) 1700 else: 1701 raise ValueError("array split does not result in an equal division") 1702 starts, ends = [0] * ndim(ary), shape(ary) 1703 _subval = 
lambda x, i, v: subvals(x, [(i, v)]) 1704 return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) 1705 for start, end in zip(split_indices[:-1], split_indices[1:])] 1706 1707@_wraps(np.split) 1708def split(ary, indices_or_sections, axis: int = 0): 1709 return _split("split", ary, indices_or_sections, axis=axis) 1710 1711def _split_on_axis(np_fun, axis): 1712 @_wraps(np_fun, update_doc=False) 1713 def f(ary, indices_or_sections): 1714 return split(ary, indices_or_sections, axis=axis) 1715 return f 1716 1717vsplit = _split_on_axis(np.vsplit, axis=0) 1718hsplit = _split_on_axis(np.hsplit, axis=1) 1719dsplit = _split_on_axis(np.dsplit, axis=2) 1720 1721@_wraps(np.array_split) 1722def array_split(ary, indices_or_sections, axis: int = 0): 1723 return _split("array_split", ary, indices_or_sections, axis=axis) 1724 1725@_wraps(np.clip) 1726def clip(a, a_min=None, a_max=None, out=None): 1727 _check_arraylike("clip", a) 1728 if out is not None: 1729 raise NotImplementedError("The 'out' argument to jnp.clip is not supported.") 1730 if a_min is None and a_max is None: 1731 raise ValueError("At most one of a_min and a_max may be None") 1732 if a_min is not None: 1733 a = maximum(a_min, a) 1734 if a_max is not None: 1735 a = minimum(a_max, a) 1736 return a 1737 1738@_wraps(np.round, update_doc=False) 1739def round(a, decimals=0, out=None): 1740 _check_arraylike("round", a) 1741 if out is not None: 1742 raise NotImplementedError("The 'out' argument to jnp.round is not supported.") 1743 dtype = _dtype(a) 1744 if issubdtype(dtype, integer): 1745 if decimals < 0: 1746 raise NotImplementedError( 1747 "integer np.round not implemented for decimals < 0") 1748 return a # no-op on integer types 1749 1750 def _round_float(x): 1751 if decimals == 0: 1752 return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) 1753 1754 # TODO(phawkins): the strategy of rescaling the value isn't necessarily a 1755 # good one since we may be left with an incorrectly rounded value at 
the 1756 # end due to precision problems. As a workaround for float16, convert to 1757 # float32, 1758 x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x 1759 factor = _constant_like(x, 10 ** decimals) 1760 out = lax.div(lax.round(lax.mul(x, factor), 1761 lax.RoundingMethod.TO_NEAREST_EVEN), factor) 1762 return lax.convert_element_type(out, dtype) if dtype == np.float16 else out 1763 1764 if issubdtype(dtype, complexfloating): 1765 return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a))) 1766 else: 1767 return _round_float(a) 1768around = round 1769 1770 1771@_wraps(np.fix) 1772def fix(x, out=None): 1773 _check_arraylike("fix", x) 1774 if out is not None: 1775 raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") 1776 zero = lax._const(x, 0) 1777 return where(lax.ge(x, zero), floor(x), ceil(x)) 1778 1779 1780@_wraps(np.modf) 1781def modf(x, out=None): 1782 _check_arraylike("modf", x) 1783 if out is not None: 1784 raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") 1785 whole = fix(x) 1786 return x - whole, whole 1787 1788 1789@_wraps(np.isfinite) 1790def isfinite(x): 1791 _check_arraylike("isfinite", x) 1792 dtype = _dtype(x) 1793 if issubdtype(dtype, floating): 1794 return lax.is_finite(x) 1795 elif issubdtype(dtype, complexfloating): 1796 return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x))) 1797 else: 1798 return full_like(x, True, dtype=bool_) 1799 1800@_wraps(np.isinf) 1801def isinf(x): 1802 _check_arraylike("isinf", x) 1803 dtype = _dtype(x) 1804 if issubdtype(dtype, floating): 1805 return lax.eq(lax.abs(x), _constant_like(x, inf)) 1806 elif issubdtype(dtype, complexfloating): 1807 re = lax.real(x) 1808 im = lax.imag(x) 1809 return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, inf)), 1810 lax.eq(lax.abs(im), _constant_like(im, inf))) 1811 else: 1812 return full_like(x, False, dtype=bool_) 1813 1814def _isposneginf(infinity, x, out): 1815 if out is 
not None: 1816 raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") 1817 dtype = _dtype(x) 1818 if issubdtype(dtype, floating): 1819 return lax.eq(x, _constant_like(x, infinity)) 1820 elif issubdtype(dtype, complexfloating): 1821 raise ValueError("isposinf/isneginf are not well defined for complex types") 1822 else: 1823 return full_like(x, False, dtype=bool_) 1824 1825isposinf = _wraps(np.isposinf)(lambda x, out=None: _isposneginf(inf, x, out)) 1826 1827isneginf = _wraps(np.isneginf)(lambda x, out=None: _isposneginf(-inf, x, out)) 1828 1829@_wraps(np.isnan) 1830def isnan(x): 1831 _check_arraylike("isnan", x) 1832 return lax.ne(x, x) 1833 1834@_wraps(np.nan_to_num) 1835def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): 1836 del copy 1837 _check_arraylike("nan_to_num", x) 1838 dtype = _dtype(x) 1839 if issubdtype(dtype, complexfloating): 1840 return lax.complex( 1841 nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf), 1842 nan_to_num(lax.imag(x), nan=nan, posinf=posinf, neginf=neginf)) 1843 info = finfo(dtypes.canonicalize_dtype(dtype)) 1844 posinf = info.max if posinf is None else posinf 1845 neginf = info.min if neginf is None else neginf 1846 x = where(isnan(x), _constant_like(x, nan), x) 1847 x = where(isposinf(x), _constant_like(x, posinf), x) 1848 x = where(isneginf(x), _constant_like(x, neginf), x) 1849 return x 1850 1851### Reducers 1852 1853def _reduction(a, name, np_fun, op, init_val, has_identity=True, 1854 preproc=None, bool_op=None, upcast_f16_for_computation=False, 1855 axis=None, dtype=None, out=None, keepdims=False, initial=None, where_=None): 1856 bool_op = bool_op or op 1857 if out is not None: 1858 raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") 1859 _check_arraylike(name, a) 1860 lax._check_user_dtype_supported(dtype, name) 1861 axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") 1862 1863 if initial is None and not has_identity: 
1864 if not size(a): 1865 raise ValueError(f"zero-size array to reduction operation {name} which has no identity") 1866 if where_ is not None: 1867 raise ValueError(f"reduction operation {name} does not have an identity, so to use a " 1868 f"where mask one has to specify 'initial'") 1869 1870 a = a if isinstance(a, ndarray) else asarray(a) 1871 a = preproc(a) if preproc else a 1872 dims = _reduction_dims(a, axis) 1873 result_dtype = dtypes.canonicalize_dtype(dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a))))) 1874 if upcast_f16_for_computation and issubdtype(result_dtype, inexact): 1875 computation_dtype = promote_types(result_dtype, float32) 1876 else: 1877 computation_dtype = result_dtype 1878 a = lax.convert_element_type(a, computation_dtype) 1879 op = op if computation_dtype != np.bool_ else bool_op 1880 # NB: in XLA, init_val must be an identity for the op, so the user-specified 1881 # initial value must be applied afterward. 1882 init_val = _reduction_init_val(a, init_val) 1883 if where_ is not None: 1884 a = where(where_, a, init_val) 1885 result = lax.reduce(a, init_val, op, dims) 1886 if initial is not None: 1887 result = op(_reduction_init_val(a, initial), result) 1888 if keepdims: 1889 result = expand_dims(result, dims) 1890 return lax.convert_element_type(result, dtype or result_dtype) 1891 1892def _reduction_dims(a, axis): 1893 if axis is None: 1894 return tuple(range(ndim(a))) 1895 elif isinstance(axis, (np.ndarray, tuple, list)): 1896 axis = tuple(_canonicalize_axis(x, ndim(a)) for x in axis) 1897 if len(axis) != len(set(axis)): 1898 raise ValueError(f"duplicate value in 'axis': {axis}") 1899 return axis 1900 else: 1901 return (_canonicalize_axis(axis, ndim(a)),) 1902 1903def _reduction_init_val(a, init_val): 1904 a_dtype = dtypes.canonicalize_dtype(_dtype(a)) 1905 if a_dtype == 'bool': 1906 return np.array(init_val > 0, dtype=a_dtype) 1907 try: 1908 return np.array(init_val, dtype=a_dtype) 1909 except OverflowError: 1910 assert 
issubdtype(a_dtype, integer) 1911 sign, info = np.sign(init_val), iinfo(a_dtype) 1912 return np.array(info.min if sign < 0 else info.max, dtype=a_dtype) 1913 1914_cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_) 1915 1916@_wraps(np.sum) 1917def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 1918 out=None, keepdims=None, initial=None, where=None): 1919 return _reduction(a, "sum", np.sum, lax.add, 0, 1920 bool_op=lax.bitwise_or, upcast_f16_for_computation=True, 1921 axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) 1922 1923@_wraps(np.prod) 1924def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 1925 out=None, keepdims=None, initial=None, where=None): 1926 return _reduction(a, "prod", np.prod, lax.mul, 1, 1927 bool_op=lax.bitwise_and, upcast_f16_for_computation=True, 1928 axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where) 1929 1930@_wraps(np.max) 1931def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 1932 keepdims=None, initial=None, where=None): 1933 return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, 1934 axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where) 1935 1936@_wraps(np.min) 1937def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 1938 keepdims=None, initial=None, where=None): 1939 return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, 1940 axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where) 1941 1942@_wraps(np.all) 1943def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 1944 keepdims=None): 1945 return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, 1946 axis=axis, out=out, keepdims=keepdims) 1947 1948@_wraps(np.any) 1949def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 1950 keepdims=None): 1951 return _reduction(a, "any", np.any, 
lax.bitwise_or, False, preproc=_cast_to_bool, 1952 axis=axis, out=out, keepdims=keepdims) 1953 1954product = prod 1955amin = min 1956amax = max 1957alltrue = all 1958sometrue = any 1959 1960@_wraps(np.mean) 1961def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 1962 out=None, keepdims=False): 1963 _check_arraylike("mean", a) 1964 lax._check_user_dtype_supported(dtype, "mean") 1965 if out is not None: 1966 raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") 1967 1968 if axis is None: 1969 normalizer = size(a) 1970 else: 1971 normalizer = np.prod(np.take(shape(a), axis)) 1972 if dtype is None: 1973 if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer): 1974 dtype = float_ 1975 else: 1976 dtype = _dtype(a) 1977 dtype = dtypes.canonicalize_dtype(dtype) 1978 1979 return lax.div( 1980 sum(a, axis, dtype=dtype, keepdims=keepdims), 1981 lax.convert_element_type(normalizer, dtype)) 1982 1983@_wraps(np.average) 1984def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, 1985 returned=False): 1986 a = asarray(a) 1987 1988 if weights is None: # Treat all weights as 1 1989 avg = mean(a, axis=axis) 1990 if axis is None: 1991 weights_sum = full((), size(a), dtype=avg.dtype) 1992 else: 1993 weights_sum = full_like(avg, a.shape[axis], dtype=avg.dtype) 1994 else: 1995 weights = asarray(weights) 1996 1997 if issubdtype(a.dtype, inexact): 1998 out_dtype = result_type(a.dtype, weights.dtype) 1999 else: 2000 out_dtype = result_type(a.dtype, weights.dtype, float_) 2001 out_dtype = dtypes.canonicalize_dtype(out_dtype) 2002 2003 a_shape = shape(a) 2004 a_ndim = len(a_shape) 2005 weights_shape = shape(weights) 2006 axis = None if axis is None else _canonicalize_axis(axis, a_ndim) 2007 2008 if a_shape != weights_shape: 2009 # Make sure the dimensions work out 2010 if axis is None: 2011 raise ValueError("Axis must be specified when shapes of a and " 2012 "weights differ.") 2013 if len(weights_shape) != 1: 
2014 raise ValueError("1D weights expected when shapes of a and " 2015 "weights differ.") 2016 if weights_shape[0] != a_shape[axis]: 2017 raise ValueError("Length of weights not " 2018 "compatible with specified axis.") 2019 2020 weights = broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape) 2021 weights = moveaxis(weights, -1, axis) 2022 2023 weights_sum = sum(weights, axis=axis, dtype=out_dtype) 2024 avg = sum(multiply(a, weights), axis=axis, dtype=out_dtype) / weights_sum 2025 2026 if returned: 2027 if avg.shape != weights_sum.shape: 2028 weights_sum = broadcast_to(weights_sum, avg.shape) 2029 return avg, weights_sum 2030 return avg 2031 2032 2033@_wraps(np.var) 2034def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 2035 out=None, ddof=0, keepdims=False): 2036 _check_arraylike("var", a) 2037 lax._check_user_dtype_supported(dtype, "var") 2038 if out is not None: 2039 raise NotImplementedError("The 'out' argument to jnp.var is not supported.") 2040 2041 a_dtype, dtype = _var_promote_types(_dtype(a), dtype) 2042 a_mean = mean(a, axis, dtype=a_dtype, keepdims=True) 2043 centered = a - a_mean 2044 if issubdtype(centered.dtype, complexfloating): 2045 centered = lax.real(lax.mul(centered, lax.conj(centered))) 2046 else: 2047 centered = lax.square(centered) 2048 2049 if axis is None: 2050 normalizer = size(a) 2051 else: 2052 normalizer = np.prod(np.take(shape(a), axis)) 2053 normalizer = normalizer - ddof 2054 2055 result = sum(centered, axis, keepdims=keepdims) 2056 out = lax.div(result, lax.convert_element_type(normalizer, result.dtype)) 2057 return lax.convert_element_type(out, dtype) 2058 2059 2060def _var_promote_types(a_dtype, dtype): 2061 if dtype: 2062 if (not issubdtype(dtype, complexfloating) and 2063 issubdtype(a_dtype, complexfloating)): 2064 msg = ("jax.numpy.var does not yet support real dtype parameters when " 2065 "computing the variance of an array of complex values. 
The " 2066 "semantics of numpy.var seem unclear in this case. Please comment " 2067 "on https://github.com/google/jax/issues/2283 if this behavior is " 2068 "important to you.") 2069 raise ValueError(msg) 2070 a_dtype = promote_types(a_dtype, dtype) 2071 else: 2072 if not issubdtype(a_dtype, inexact): 2073 dtype = a_dtype = dtypes.canonicalize_dtype(float_) 2074 else: 2075 dtype = _complex_elem_type(a_dtype) 2076 a_dtype = promote_types(a_dtype, float32) 2077 return a_dtype, dtype 2078 2079 2080@_wraps(np.std) 2081def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 2082 out=None, ddof=0, keepdims=False): 2083 _check_arraylike("std", a) 2084 lax._check_user_dtype_supported(dtype, "std") 2085 if out is not None: 2086 raise NotImplementedError("The 'out' argument to jnp.std is not supported.") 2087 return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)) 2088 2089 2090@_wraps(np.ptp) 2091def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 2092 keepdims=False): 2093 _check_arraylike("ptp", a) 2094 if out is not None: 2095 raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.") 2096 x = amax(a, axis=axis, keepdims=keepdims) 2097 y = amin(a, axis=axis, keepdims=keepdims) 2098 return lax.sub(x, y) 2099 2100 2101@_wraps(np.allclose) 2102def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): 2103 return all(isclose(a, b, rtol, atol, equal_nan)) 2104 2105 2106@_wraps(np.count_nonzero) 2107def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, 2108 keepdims=False): 2109 _check_arraylike("count_nonzero", a) 2110 return sum(lax.ne(a, _constant_like(a, 0)), axis=axis, 2111 dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims) 2112 2113 2114_NONZERO_DOC = """\ 2115At present, JAX does not support JIT-compilation of :py:func:`jax.numpy.nonzero` 2116because its output shape is data-dependent. 
2117""" 2118 2119@_wraps(np.nonzero, lax_description=_NONZERO_DOC) 2120def nonzero(a): 2121 # Note: this function cannot be jitted because its output has a dynamic 2122 # shape. 2123 a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero") 2124 dims = shape(a) 2125 ndims = len(dims) 2126 ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)] 2127 d = concatenate(ds, axis=-1) 2128 indexes = d[a != 0] 2129 return tuple(indexes[..., i] for i in range(ndims)) 2130 2131 2132@_wraps(np.flatnonzero) 2133def flatnonzero(a): 2134 return nonzero(ravel(a))[0] 2135 2136 2137def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan, 2138 axis=None, keepdims=None, **kwargs): 2139 _check_arraylike(name, a) 2140 out = jnp_reduction(where(isnan(a), _reduction_init_val(a, init_val), a), 2141 axis=axis, keepdims=keepdims, **kwargs) 2142 if nan_if_all_nan: 2143 return where(all(isnan(a), axis=axis, keepdims=keepdims), 2144 _constant_like(a, nan), out) 2145 else: 2146 return out 2147 2148@_wraps(np.nanmin) 2149def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 2150 keepdims=None): 2151 return _nan_reduction(a, 'nanmin', min, inf, nan_if_all_nan=True, 2152 axis=axis, out=out, keepdims=keepdims) 2153 2154@_wraps(np.nanmax) 2155def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 2156 keepdims=None): 2157 return _nan_reduction(a, 'nanmax', max, -inf, nan_if_all_nan=True, 2158 axis=axis, out=out, keepdims=keepdims) 2159 2160@_wraps(np.nansum) 2161def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 2162 out=None, keepdims=None): 2163 return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False, 2164 axis=axis, dtype=dtype, out=out, keepdims=keepdims) 2165 2166@_wraps(np.nanprod) 2167def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 2168 out=None, keepdims=None): 2169 return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False, 2170 
axis=axis, dtype=dtype, out=out, keepdims=keepdims) 2171 2172@_wraps(np.nanmean) 2173def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 2174 out=None, keepdims=False): 2175 _check_arraylike("nanmean", a) 2176 lax._check_user_dtype_supported(dtype, "nanmean") 2177 if out is not None: 2178 raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") 2179 if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer): 2180 return mean(a, axis, dtype, out, keepdims) 2181 if dtype is None: 2182 dtype = _dtype(a) 2183 nan_mask = logical_not(isnan(a)) 2184 normalizer = sum(nan_mask, axis=axis, dtype=int32, keepdims=keepdims) 2185 normalizer = lax.convert_element_type(normalizer, dtype) 2186 td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims), normalizer) 2187 return td 2188 2189 2190@_wraps(np.nanvar) 2191def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 2192 out=None, ddof=0, keepdims=False): 2193 _check_arraylike("nanvar", a) 2194 lax._check_user_dtype_supported(dtype, "nanvar") 2195 if out is not None: 2196 raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") 2197 2198 a_dtype, dtype = _var_promote_types(_dtype(a), dtype) 2199 a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True) 2200 centered = a - a_mean 2201 if issubdtype(centered.dtype, complexfloating): 2202 centered = lax.real(lax.mul(centered, lax.conj(centered))) 2203 else: 2204 centered = lax.square(centered) 2205 2206 normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims) 2207 normalizer = normalizer - ddof 2208 if config.omnistaging_enabled: 2209 normalizer_mask = lax.le(normalizer, 0) 2210 else: 2211 zero = lax.full_like(normalizer, 0, shape=()) 2212 normalizer_mask = lax.le(normalizer, zero) 2213 2214 result = nansum(centered, axis, keepdims=keepdims) 2215 result = where(normalizer_mask, nan, result) 2216 divisor = where(normalizer_mask, 1, normalizer) 2217 out = 
lax.div(result, lax.convert_element_type(divisor, result.dtype)) 2218 return lax.convert_element_type(out, dtype) 2219 2220 2221@_wraps(np.nanstd) 2222def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, 2223 out=None, ddof=0, keepdims=False): 2224 _check_arraylike("nanstd", a) 2225 lax._check_user_dtype_supported(dtype, "nanstd") 2226 if out is not None: 2227 raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.") 2228 return sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)) 2229 2230 2231def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0): 2232 # We want to allow XLA to fuse the pad and reduce-window operators to 2233 # avoid materializing the padded output. 2234 # Consider removing `jit` once again if reduce-window is generalized to 2235 # support arbitrary padding. 2236 @partial(jit, static_argnums=(1, 2)) 2237 def _cumulative_reduction(a, axis, dtype): 2238 if axis is None or isscalar(a): 2239 a = ravel(a) 2240 axis = 0 2241 2242 a_shape = list(shape(a)) 2243 num_dims = len(a_shape) 2244 axis = _canonicalize_axis(axis, num_dims) 2245 2246 if fill_nan: 2247 a = where(isnan(a), _constant_like(a, fill_value), a) 2248 2249 if not dtype and _dtype(a) == bool_: 2250 dtype = int_ 2251 if dtype: 2252 a = lax.convert_element_type(a, dtype) 2253 2254 return reduction(a, axis) 2255 2256 @_wraps(np_reduction) 2257 def cumulative_reduction(a, 2258 axis: Optional[Union[int, Tuple[int, ...]]] = None, 2259 dtype=None, out=None): 2260 _check_arraylike(np_reduction.__name__, a) 2261 if out is not None: 2262 raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} " 2263 f"is not supported.") 2264 lax._check_user_dtype_supported(dtype, np_reduction.__name__) 2265 # jit doesn't support kwargs as static_args. 
2266 return _cumulative_reduction(a, axis, dtype) 2267 return cumulative_reduction 2268 2269 2270cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False) 2271cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False) 2272cumproduct = cumprod 2273nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum, 2274 fill_nan=True, fill_value=0) 2275nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, 2276 fill_nan=True, fill_value=1) 2277 2278 2279@_wraps(np.unwrap) 2280def unwrap(p, discont=pi, axis: int = -1): 2281 _check_arraylike("unwrap", p) 2282 dd = diff(p, axis=axis) 2283 ddmod = mod(dd + pi, 2 * pi) - pi 2284 ddmod = where((ddmod == -pi) & (dd > 0), pi, ddmod) 2285 2286 ph_correct = where(abs(dd) < discont, 0, ddmod - dd) 2287 2288 up = concatenate(( 2289 lax.slice_in_dim(p, 0, 1, axis=axis), 2290 lax.slice_in_dim(p, 1, None, axis=axis) + cumsum(ph_correct, axis=axis) 2291 ), axis=axis) 2292 2293 return up 2294 2295 2296### Array-creation functions 2297 2298def _check_no_padding(axis_padding, mode): 2299 if (axis_padding[0] > 0 or axis_padding[1] > 0): 2300 msg = "Cannot apply '{}' padding to empty axis" 2301 raise ValueError(msg.format(mode)) 2302 2303 2304def _pad_constant(array, pad_width, constant_values): 2305 nd = ndim(array) 2306 constant_values = broadcast_to(asarray(constant_values), (nd, 2)) 2307 constant_values = lax.convert_element_type(constant_values, array.dtype) 2308 for i in range(nd): 2309 widths = [(0, 0, 0)] * nd 2310 widths[i] = (pad_width[i, 0], 0, 0) 2311 array = lax.pad(array, constant_values[i, 0], widths) 2312 widths[i] = (0, pad_width[i, 1], 0) 2313 array = lax.pad(array, constant_values[i, 1], widths) 2314 return array 2315 2316 2317def _pad_wrap(array, pad_width): 2318 for i in range(ndim(array)): 2319 if array.shape[i] == 0: 2320 _check_no_padding(pad_width[i], "wrap") 2321 continue 2322 size = array.shape[i] 2323 repeats, (left_remainder, right_remainder) = 
_divmod(pad_width[i], size) 2324 total_repeats = repeats.sum() + 1 2325 parts = [] 2326 if left_remainder: 2327 parts += [lax.slice_in_dim(array, size - left_remainder, size, axis=i)] 2328 parts += total_repeats * [array] 2329 if right_remainder: 2330 parts += [lax.slice_in_dim(array, 0, right_remainder, axis=i)] 2331 array = lax.concatenate(parts, dimension=i) 2332 return array 2333 2334 2335def _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type): 2336 assert mode in ("symmetric", "reflect") 2337 assert reflect_type in ("even", "odd") 2338 2339 for i in range(ndim(array)): 2340 if array.shape[i] == 0: 2341 _check_no_padding(pad_width[i], mode) 2342 continue 2343 2344 n = array.shape[i] 2345 offset = 1 if (mode == "reflect" and n > 1) else 0 2346 2347 def build_padding(array, padding, before): 2348 if before: 2349 edge = lax.slice_in_dim(array, 0, 1, axis=i) 2350 else: 2351 edge = lax.slice_in_dim(array, -1, None, axis=i) 2352 2353 while padding > 0: 2354 curr_pad = _min(padding, n - offset) 2355 padding -= curr_pad 2356 2357 if before: 2358 start = offset 2359 stop = offset + curr_pad 2360 else: 2361 start = -(curr_pad + offset) 2362 stop = None if (mode == "symmetric" or n == 1) else -1 2363 2364 x = lax.slice_in_dim(array, start, stop, axis=i) 2365 x = flip(x, axis=i) 2366 2367 if reflect_type == 'odd': 2368 x = 2 * edge - x 2369 if n > 1: 2370 if before: 2371 edge = lax.slice_in_dim(x, 0, 1, axis=i) 2372 else: 2373 edge = lax.slice_in_dim(x, -1, None, axis=i) 2374 2375 if before: 2376 array = lax.concatenate([x, array], dimension=i) 2377 else: 2378 array = lax.concatenate([array, x], dimension=i) 2379 return array 2380 2381 array = build_padding(array, pad_width[i, 0], before=True) 2382 array = build_padding(array, pad_width[i, 1], before=False) 2383 return array 2384 2385 2386def _pad_edge(array, pad_width): 2387 nd = ndim(array) 2388 for i in range(nd): 2389 if array.shape[i] == 0: 2390 _check_no_padding(pad_width[i], "edge") 2391 continue 2392 
2393 n = array.shape[i] 2394 npad_before, npad_after = pad_width[i] 2395 2396 edge_before = lax.slice_in_dim(array, 0, 1, axis=i) 2397 pad_before = repeat(edge_before, npad_before, axis=i) 2398 2399 edge_after = lax.slice_in_dim(array, n-1, n, axis=i) 2400 pad_after = repeat(edge_after, npad_after, axis=i) 2401 2402 array = lax.concatenate([pad_before, array, pad_after], dimension=i) 2403 return array 2404 2405 2406def _pad_linear_ramp(array, pad_width, end_values): 2407 for axis in range(ndim(array)): 2408 edge_before = lax.slice_in_dim(array, 0, 1, axis=axis) 2409 edge_after = lax.slice_in_dim(array, -1, None, axis=axis) 2410 ramp_before = linspace( 2411 start=end_values[axis][0], 2412 stop=edge_before.squeeze(axis), # Dimension is replaced by linspace 2413 num=pad_width[axis][0], 2414 endpoint=False, 2415 dtype=array.dtype, 2416 axis=axis 2417 ) 2418 ramp_after = linspace( 2419 start=end_values[axis][1], 2420 stop=edge_after.squeeze(axis), # Dimension is replaced by linspace 2421 num=pad_width[axis][1], 2422 endpoint=False, 2423 dtype=array.dtype, 2424 axis=axis 2425 ) 2426 2427 # Reverse linear space in appropriate dimension 2428 ramp_after = flip(ramp_after, axis) 2429 2430 array = lax.concatenate([ramp_before, array, ramp_after], dimension=axis) 2431 return array 2432 2433 2434def _pad_stats(array, pad_width, stat_length, stat_func): 2435 nd = ndim(array) 2436 for i in range(nd): 2437 if stat_length is None: 2438 stat_before = stat_func(array, axis=i, keepdims=True) 2439 stat_after = stat_before 2440 else: 2441 array_length = array.shape[i] 2442 length_before, length_after = stat_length[i] 2443 if length_before == 0 or length_after == 0: 2444 raise ValueError("stat_length of 0 yields no value for padding") 2445 2446 # Limit stat_length to length of array. 
2447 length_before = _min(length_before, array_length) 2448 length_after = _min(length_after, array_length) 2449 2450 slice_before = lax.slice_in_dim(array, 0, length_before, axis=i) 2451 slice_after = lax.slice_in_dim(array, -length_after, None, axis=i) 2452 stat_before = stat_func(slice_before, axis=i, keepdims=True) 2453 stat_after = stat_func(slice_after, axis=i, keepdims=True) 2454 2455 if np.issubdtype(array.dtype, np.integer): 2456 stat_before = round(stat_before) 2457 stat_after = round(stat_after) 2458 2459 stat_before = stat_before.astype(array.dtype) 2460 stat_after = stat_after.astype(array.dtype) 2461 2462 npad_before, npad_after = pad_width[i] 2463 pad_before = repeat(stat_before, npad_before, axis=i) 2464 pad_after = repeat(stat_after, npad_after, axis=i) 2465 2466 array = lax.concatenate([pad_before, array, pad_after], dimension=i) 2467 return array 2468 2469 2470def _pad_empty(array, pad_width): 2471 # Note: jax.numpy.empty = jax.numpy.zeros 2472 for i in range(ndim(array)): 2473 shape_before = array.shape[:i] + (pad_width[i][0],) + array.shape[i + 1:] 2474 pad_before = empty(shape_before, dtype=array.dtype) 2475 2476 shape_after = array.shape[:i] + (pad_width[i][1],) + array.shape[i + 1:] 2477 pad_after = empty(shape_after, dtype=array.dtype) 2478 array = lax.concatenate([pad_before, array, pad_after], dimension=i) 2479 return array 2480 2481 2482def _pad_func(array, pad_width, func, **kwargs): 2483 pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width") 2484 padded = _pad_constant(array, np.array(pad_width), 0) 2485 for axis in range(ndim(padded)): 2486 padded = apply_along_axis(func, axis, padded, pad_width[axis], axis, kwargs) 2487 return padded 2488 2489 2490def _broadcast_to_pairs(nvals, nd, name): 2491 nvals_shape = np.shape(nvals) 2492 if nvals_shape == (nd, 2): 2493 # ((before_1, after_1), ..., (before_N, after_N)) 2494 pass 2495 elif nvals_shape == (1, 2): 2496 # ((before, after),) 2497 nvals = nvals * nd 2498 elif 
nvals_shape == (2,): 2499 # (before, after) (not in the numpy docstring but works anyway) 2500 before, after = nvals 2501 nvals = (nvals,) * nd 2502 elif nvals_shape == (1,): 2503 # (pad,) 2504 nvals, = nvals 2505 nvals = ((nvals, nvals),) * nd 2506 elif nvals_shape == (): 2507 # pad 2508 nvals = ((nvals, nvals),) * nd 2509 else: 2510 raise ValueError(f"{name} given unexpected structure: {nvals}. " 2511 "See docstring for valid {name} formats.") 2512 return nvals 2513 2514 2515@partial(jit, static_argnums=(1, 2, 4, 5, 6)) 2516def _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type): 2517 array = asarray(array) 2518 nd = ndim(array) 2519 2520 if nd == 0: 2521 return array 2522 2523 stat_funcs = {"maximum": amax, "minimum": amin, 2524 "mean": mean, "median": median} 2525 2526 pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width") 2527 pad_width = np.array(pad_width) 2528 assert pad_width.shape == (nd, 2), pad_width 2529 2530 if np.any(pad_width < 0): 2531 raise ValueError("index can't contain negative values") 2532 2533 if mode == "constant": 2534 return _pad_constant(array, pad_width, constant_values) 2535 2536 elif mode == "wrap": 2537 return _pad_wrap(array, pad_width) 2538 2539 elif mode in ("symmetric", "reflect"): 2540 return _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type) 2541 2542 elif mode == "edge": 2543 return _pad_edge(array, pad_width) 2544 2545 elif mode == "linear_ramp": 2546 end_values = _broadcast_to_pairs(end_values, nd, "end_values") 2547 return _pad_linear_ramp(array, pad_width, end_values) 2548 2549 elif mode in stat_funcs: 2550 if stat_length is not None: 2551 stat_length = _broadcast_to_pairs(stat_length, nd, "stat_length") 2552 return _pad_stats(array, pad_width, stat_length, stat_funcs[mode]) 2553 2554 elif mode == "empty": 2555 return _pad_empty(array, pad_width) 2556 2557 else: 2558 assert False, ("Should not be reached since pad already handled unsupported and" 2559 "not implemented 
modes") 2560 2561 2562@_wraps(np.pad, lax_description="""\ 2563Unlike numpy, JAX "function" mode's argument (which is another function) should return 2564the modified array. This is because Jax arrays are immutable. 2565(In numpy, "function" mode's argument should modify a rank 1 array in-place.) 2566""") 2567def pad(array, pad_width, mode="constant", **kwargs): 2568 if isinstance(pad_width, Iterable): 2569 pad_width = tuple( 2570 tuple(int(i) for i in x) if isinstance(x, Iterable) else x 2571 for x in pad_width) 2572 2573 if callable(mode): 2574 return _pad_func(array, pad_width, mode, **kwargs) 2575 2576 allowed_kwargs = { 2577 'empty': [], 'edge': [], 'wrap': [], 2578 'constant': ['constant_values'], 2579 'linear_ramp': ['end_values'], 2580 'maximum': ['stat_length'], 2581 'mean': ['stat_length'], 2582 'median': ['stat_length'], 2583 'minimum': ['stat_length'], 2584 'reflect': ['reflect_type'], 2585 'symmetric': ['reflect_type'], 2586 } 2587 try: 2588 unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode]) 2589 except KeyError: 2590 msg = "Unimplemented padding mode '{}' for np.pad." 2591 raise NotImplementedError(msg.format(mode)) 2592 if unsupported_kwargs: 2593 raise ValueError("unsupported keyword arguments for mode '{}': {}" 2594 .format(mode, unsupported_kwargs)) 2595 # Set default value if not given. 
2596 constant_values = kwargs.get('constant_values', 0) 2597 stat_length = kwargs.get('stat_length', None) 2598 end_values = kwargs.get('end_values', 0) 2599 reflect_type = kwargs.get('reflect_type', "even") 2600 2601 return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type) 2602 2603 2604@_wraps(np.stack) 2605def stack(arrays, axis: int =0, out=None): 2606 if not len(arrays): 2607 raise ValueError("Need at least one array to stack.") 2608 if out is not None: 2609 raise NotImplementedError("The 'out' argument to jnp.stack is not supported.") 2610 _check_arraylike("stack", *arrays) 2611 shape0 = shape(arrays[0]) 2612 axis = _canonicalize_axis(axis, len(shape0) + 1) 2613 new_arrays = [] 2614 for a in arrays: 2615 if shape(a) != shape0: 2616 raise ValueError("All input arrays must have the same shape.") 2617 new_arrays.append(expand_dims(a, axis)) 2618 return concatenate(new_arrays, axis=axis) 2619 2620@_wraps(np.tile) 2621def tile(A, reps): 2622 _check_arraylike("tile", A) 2623 try: 2624 iter(reps) 2625 except TypeError: 2626 reps = (reps,) 2627 reps = tuple(operator.index(rep) for rep in reps) 2628 A_shape = (1,) * (len(reps) - ndim(A)) + shape(A) 2629 reps = (1,) * (len(A_shape) - len(reps)) + reps 2630 result = broadcast_to(reshape(A, [j for i in A_shape for j in [1, i]]), 2631 [k for pair in zip(reps, A_shape) for k in pair]) 2632 return reshape(result, tuple(np.multiply(A_shape, reps))) 2633 2634@_wraps(np.concatenate) 2635def concatenate(arrays, axis: int = 0): 2636 _check_arraylike("concatenate", *arrays) 2637 if not len(arrays): 2638 raise ValueError("Need at least one array to concatenate.") 2639 if ndim(arrays[0]) == 0: 2640 raise ValueError("Zero-dimensional arrays cannot be concatenated.") 2641 if axis is None: 2642 return concatenate([ravel(a) for a in arrays], axis=0) 2643 axis = _canonicalize_axis(axis, ndim(arrays[0])) 2644 arrays = _promote_dtypes(*arrays) 2645 # lax.concatenate can be slow to compile for wide 
concatenations, so form a 2646 # tree of concatenations as a workaround especially for op-by-op mode. 2647 # (https://github.com/google/jax/issues/653). 2648 k = 16 2649 if len(arrays) == 1: 2650 return asarray(arrays[0]) 2651 else: 2652 while len(arrays) > 1: 2653 arrays = [lax.concatenate(arrays[i:i+k], axis) 2654 for i in range(0, len(arrays), k)] 2655 return arrays[0] 2656 2657 2658@_wraps(np.vstack) 2659def vstack(tup): 2660 return concatenate([atleast_2d(m) for m in tup], axis=0) 2661row_stack = vstack 2662 2663 2664@_wraps(np.hstack) 2665def hstack(tup): 2666 arrs = [atleast_1d(m) for m in tup] 2667 if arrs[0].ndim == 1: 2668 return concatenate(arrs, 0) 2669 return concatenate(arrs, 1) 2670 2671 2672@_wraps(np.dstack) 2673def dstack(tup): 2674 return concatenate([atleast_3d(m) for m in tup], axis=2) 2675 2676 2677@_wraps(np.column_stack) 2678def column_stack(tup): 2679 arrays = [] 2680 for v in tup: 2681 arr = asarray(v) 2682 if arr.ndim < 2: 2683 arr = atleast_2d(arr).T 2684 arrays.append(arr) 2685 return concatenate(arrays, 1) 2686 2687 2688@_wraps(np.choose) 2689def choose(a, choices, out=None, mode='raise'): 2690 if out is not None: 2691 raise NotImplementedError("The 'out' argument to jnp.choose is not supported.") 2692 _check_arraylike('choose', a, *choices) 2693 if not issubdtype(_dtype(a), integer): 2694 raise ValueError("`a` array must be integer typed") 2695 N = len(choices) 2696 2697 if mode == 'raise': 2698 a = core.concrete_or_error(asarray, a, 2699 "The error occurred because jnp.choose was jit-compiled" 2700 " with mode='raise'. Use mode='wrap' or mode='clip' instead.") 2701 if any((a < 0) | (a >= N)): 2702 raise ValueError("invalid entry in choice array") 2703 elif mode == 'wrap': 2704 a = a % N 2705 elif mode == 'clip': 2706 a = clip(a, 0, N - 1) 2707 else: 2708 raise ValueError(f"mode={mode!r} not understood. 
Must be 'raise', 'wrap', or 'clip'")

  a, *choices = broadcast_arrays(a, *choices)
  return array(choices)[(a,) + indices(a.shape, sparse=True)]


def _atleast_nd(x, n):
  """Prepends size-1 axes to `x` (via lax.broadcast) until ndim(x) >= n."""
  m = ndim(x)
  return lax.broadcast(x, (1,) * (n - m)) if m < n else x

def _block(xs):
  """Recursive helper for `block`; returns ``(array, depth)``.

  A leaf (non-list) has depth 1. A list whose elements have depth d is
  concatenated along axis ``-d`` after every piece is promoted to a common
  rank, and reports depth ``d + 1``. Tuples, empty lists and lists whose
  elements have differing depths raise ValueError.
  """
  if isinstance(xs, tuple):
    raise ValueError("jax.numpy.block does not allow tuples, got {}"
                     .format(xs))
  elif isinstance(xs, list):
    if len(xs) == 0:
      raise ValueError("jax.numpy.block does not allow empty list arguments")
    xs, depths = unzip2([_block(x) for x in xs])
    if _any(d != depths[0] for d in depths[1:]):
      raise ValueError("Mismatched list depths in jax.numpy.block")
    # All pieces must share a rank at least as large as the nesting depth.
    rank = _max(depths[0], _max(ndim(x) for x in xs))
    xs = [_atleast_nd(x, rank) for x in xs]
    return concatenate(xs, axis=-depths[0]), depths[0] + 1
  else:
    return asarray(xs), 1

@_wraps(np.block)
@jit
def block(arrays):
  # Jitted entry point; the nesting depth reported by _block is discarded.
  out, _ = _block(arrays)
  return out


@_wraps(np.atleast_1d, update_doc=False)
def atleast_1d(*arys):
  # One argument returns a single array; several arguments return a list.
  if len(arys) == 1:
    arr = asarray(arys[0])
    return arr if ndim(arr) >= 1 else reshape(arr, -1)
  else:
    return [atleast_1d(arr) for arr in arys]


@_wraps(np.atleast_2d, update_doc=False)
def atleast_2d(*arys):
  # 1-D inputs gain a leading axis; 0-D inputs gain two leading axes.
  if len(arys) == 1:
    arr = asarray(arys[0])
    if ndim(arr) >= 2:
      return arr
    elif ndim(arr) == 1:
      return expand_dims(arr, axis=0)
    else:
      return expand_dims(arr, axis=(0, 1))
  else:
    return [atleast_2d(arr) for arr in arys]


@_wraps(np.atleast_3d, update_doc=False)
def atleast_3d(*arys):
  # Follows NumPy: a 1-D array of length N becomes (1, N, 1) and a 2-D
  # array of shape (M, N) becomes (M, N, 1).
  if len(arys) == 1:
    arr = asarray(arys[0])
    if ndim(arr) == 0:
      arr = expand_dims(arr, axis=(0, 1, 2))
    elif ndim(arr) == 1:
      arr = expand_dims(arr, axis=(0, 2))
    elif ndim(arr) == 2:
      arr = expand_dims(arr, axis=2)
    return arr
  else:
    return [atleast_3d(arr) for arr in arys]


@_wraps(np.array)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
  # Only the default memory order is supported.
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")
  lax._check_user_dtype_supported(dtype, "array")
  dtype = dtype and dtypes.canonicalize_dtype(dtype)

  # Inputs with no Tracer/DeviceArray leaves are converted on the host by
  # NumPy first; afterwards no Python scalar can reach the branches below.
  if _can_call_numpy_array(object):
    object = _np_array(object, dtype=dtype, ndmin=ndmin, copy=False)
  assert type(object) not in dtypes.python_scalar_dtypes

  if type(object) is np.ndarray:
    out = _device_put_raw(object)
    if dtype: assert _dtype(out) == dtype
  elif isinstance(object, (DeviceArray, core.Tracer)):
    if isinstance(object, DeviceArray) and copy:
      # We perform a copy by bouncing back to the host
      # TODO(phawkins): add a device runtime function to copy a buffer
      out = _device_put_raw(_np_asarray(object))
    else:
      out = object
  elif isinstance(object, (list, tuple)):
    # Sequences that still contain device values are converted element-wise
    # and stacked along a new leading axis.
    if object:
      out = stack([asarray(elt, dtype=dtype) for elt in object])
    else:
      out = _device_put_raw(_np_array([], dtype=dtype))
  else:
    # Last resort: objects exposing the buffer protocol are viewed as a NumPy
    # array and converted recursively.
    try:
      view = memoryview(object)
    except TypeError:
      pass  # `object` does not support the buffer interface.
    else:
      return array(_np_asarray(view), dtype, copy)

    raise TypeError("Unexpected input type for array: {}".format(type(object)))

  if dtype and _dtype(out) != dtype:
    out = lax.convert_element_type(out, dtype)

  # Prepend size-1 axes to honor `ndmin`.
  if ndmin > ndim(out):
    out = lax.broadcast(out, (1,) * (ndmin - ndim(out)))
  return out

def _can_call_numpy_array(x):
  """Returns True if no leaf of the pytree `x` is a Tracer or DeviceArray."""
  return _all(not isinstance(l, (core.Tracer, DeviceArray))
              for l in tree_leaves(x))


@_wraps(np.asarray)
def asarray(a, dtype=None, order=None):
  # Same as `array`, but with copy=False, so DeviceArray inputs are reused.
  lax._check_user_dtype_supported(dtype, "asarray")
  dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
  return array(a, dtype=dtype, copy=False, order=order)


@_wraps(np.zeros_like)
def zeros_like(a, dtype=None, shape=None):
  # A scalar `shape` is promoted to a 1-tuple before delegating to lax.
  _check_arraylike("zeros_like", a)
  lax._check_user_dtype_supported(dtype, "zeros_like")
  if np.isscalar(shape):
    shape = (shape,)
  return lax.full_like(a, 0, dtype, shape)


@_wraps(np.ones_like)
def ones_like(a, dtype=None, shape=None):
  # A scalar `shape` is promoted to a 1-tuple before delegating to lax.
  _check_arraylike("ones_like", a)
  lax._check_user_dtype_supported(dtype, "ones_like")
  if np.isscalar(shape):
    shape = (shape,)
  return lax.full_like(a, 1, dtype, shape)


@_wraps(np.full)
def full(shape, fill_value, dtype=None):
  # A 0-d `shape` is promoted to a 1-tuple.
  lax._check_user_dtype_supported(dtype, "full")
  shape = (shape,) if ndim(shape) == 0 else shape
  return lax.full(shape, fill_value, dtype)


@_wraps(np.full_like)
def full_like(a, fill_value, dtype=None, shape=None):
  # A scalar `shape` is promoted to a 1-tuple before delegating to lax.
  _check_arraylike("full_like", a)
  lax._check_user_dtype_supported(dtype, "full_like")
  if np.isscalar(shape):
    shape = (shape,)
  return lax.full_like(a, fill_value, dtype, shape)


@_wraps(np.zeros)
def zeros(shape, dtype=None):
  # Generators are rejected explicitly; dtype defaults to float_.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  lax._check_user_dtype_supported(dtype, "zeros")
  dtype = float_ if dtype is None else dtype
  shape = (shape,) if ndim(shape) == 0 else shape
  return lax.full(shape, 0, dtype)

@_wraps(np.ones)
def ones(shape, dtype=None):
  # Generators are rejected explicitly; dtype defaults to float_.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  lax._check_user_dtype_supported(dtype, "ones")
  dtype = float_ if dtype is None else dtype
  shape = (shape,) if ndim(shape) == 0 else shape
  return lax.full(shape, 1, dtype)


@_wraps(np.array_equal)
def array_equal(a1, a2, equal_nan=False):
  # Unconvertible inputs and differing shapes compare unequal (no broadcast).
  try:
    a1, a2 = asarray(a1), asarray(a2)
  except Exception:
    return False
  if shape(a1) != shape(a2):
    return False
  eq = asarray(a1 == a2)
  if equal_nan:
    # Treat positions where both inputs are NaN as equal.
    eq = logical_or(eq, logical_and(isnan(a1), isnan(a2)))
  return all(eq)


@_wraps(np.array_equiv)
def array_equiv(a1, a2):
  # Like array_equal, but the comparison broadcasts the inputs.
  try:
    a1, a2 = asarray(a1), asarray(a2)
  except Exception:
    return False
  try:
    eq = equal(a1, a2)
  except ValueError:
    # shapes are not broadcastable
    return False
  return all(eq)


# We can't create uninitialized arrays in XLA; use zeros for empty.
empty_like = zeros_like
empty = zeros


@_wraps(np.eye)
def eye(N, M=None, k=0, dtype=None):
  # N, M and k must be index-like (operator.index); negative sizes raise.
  lax._check_user_dtype_supported(dtype, "eye")
  dtype = float_ if dtype is None else dtype
  N = operator.index(N)
  M = N if M is None else operator.index(M)
  if N < 0 or M < 0:
    raise ValueError(f"negative dimensions are not allowed, got {N} and {M}")
  k = operator.index(k)
  return lax._eye(dtype, (N, M), k)


@_wraps(np.identity)
def identity(n, dtype=None):
  lax._check_user_dtype_supported(dtype, "identity")
  return eye(n, dtype=dtype)


@_wraps(np.arange)
def arange(start, stop=None, step=None, dtype=None):
  # All arguments must be concrete (not traced): the output length depends
  # on their values.
  lax._check_user_dtype_supported(dtype, "arange")
  require = partial(core.concrete_or_error, _np_asarray)
  msg = "It arose in jax.numpy.arange argument `{}`.".format
  if stop is None and step is None:
    # Single-argument form: the lone argument acts as `stop` (hence the
    # error label) and the result is built directly with lax.iota.
    start = require(start, msg("stop"))
    dtype = dtype or _dtype(start)
    return lax.iota(dtype, np.ceil(start))  # avoids materializing
  else:
    # General form: compute with NumPy on the host, then transfer.
    start = require(start, msg("start"))
    stop = None if stop is None else require(stop, msg("stop"))
    step = None if step is None else require(step, msg("step"))
    if dtype is None:
      dtype = _dtype(start, *(x for x in [stop, step] if x is not None))
    return array(np.arange(start, stop=stop, step=step, dtype=dtype))


def _wrap_numpy_nullary_function(f):
  """Adapts `f` to return a DeviceArray instead of an np.ndarray.

  `f` cannot have any non-static array arguments.
  """
  @_wraps(f, update_doc=False)
  def wrapper(*args, **kwargs):
    return asarray(f(*args, **kwargs))
  return wrapper


@_wraps(np.linspace)
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
             axis: int = 0):
  """Implementation of linspace differentiable in start and stop args."""
  lax._check_user_dtype_supported(dtype, "linspace")
  if num < 0:
    raise ValueError("Number of samples, %s, must be non-negative." % num)

  # Work in an inexact dtype at least as wide as float_, then cast back to
  # `dtype` at the end.
  dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
  computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)

  bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
  broadcast_start = broadcast_to(start, bounds_shape)
  broadcast_stop = broadcast_to(stop, bounds_shape)
  # The output has one more dimension than the broadcast bounds, so a
  # negative axis is normalized against len(bounds_shape) + 1.
  axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
  bounds_shape.insert(axis, 1)
  iota_shape = [1,] * len(bounds_shape)
  iota_shape[axis] = num
  div = (num - 1) if endpoint else num
  if num > 1:
    delta = lax.convert_element_type(stop - start, computation_dtype) / div
    if issubdtype(dtype, integer):
      # This is similar to how numpy computes linspace, but it
      # can fail to recover the endpoints in float32 arithmetic.
      out = (reshape(broadcast_start, bounds_shape) +
             reshape(lax.iota(dtype, num), iota_shape) *
             reshape(delta, bounds_shape))
    else:
      # This approach recovers the endpoints with float32 arithmetic,
      # but can lead to rounding errors for integer outputs.
      step = reshape(lax.iota(computation_dtype, num), iota_shape) / div
      out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
             reshape(broadcast_stop, bounds_shape) * step)
  elif num == 1:
    delta = nan if endpoint else stop - start
    out = reshape(broadcast_start, bounds_shape)
  else:  # num == 0 degenerate case, match numpy behavior
    empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
    empty_shape.insert(axis, 0)
    delta = nan
    out = reshape(array([], dtype=dtype), empty_shape)
  if retstep:
    return lax.convert_element_type(out, dtype), delta
  else:
    return lax.convert_element_type(out, dtype)


@_wraps(np.logspace)
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
             axis: int = 0):
  """Implementation of logspace differentiable in start and stop args."""
  # Computes base ** linspace(start, stop) in computation_dtype, then casts.
  lax._check_user_dtype_supported(dtype, "logspace")
  dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
  computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)
  lin = linspace(start, stop, num,
                 endpoint=endpoint, retstep=False, dtype=None, axis=axis)
  return lax.convert_element_type(power(base, lin), dtype)


@_wraps(np.geomspace)
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
  """Implementation of geomspace differentiable in start and stop args."""
  lax._check_user_dtype_supported(dtype, "geomspace")
  dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
  computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints
  signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2
  # The sequence is generated along axis 0 and moved to `axis` afterwards.
  res = signflip * logspace(log10(signflip * start),
                            log10(signflip * stop), num,
                            endpoint=endpoint, base=10.0,
                            dtype=computation_dtype, axis=0)
  if axis != 0:
    res = moveaxis(res, 0, axis)
  return lax.convert_element_type(res, dtype)


@_wraps(np.meshgrid)
def meshgrid(*args, **kwargs):
  # 'xy' indexing is implemented by swapping the first two inputs, building
  # an 'ij' grid, and swapping the first two outputs back.
  indexing = kwargs.get("indexing", "xy")
  sparse = kwargs.get("sparse", False)
  copy = kwargs.get("copy", True)
  if not copy:
    raise ValueError("jax.numpy.meshgrid only supports copy=True")

  args = list(args)
  if indexing == "xy":
    if len(args) >= 2:
      args[0], args[1] = args[1], args[0]
  elif indexing != "ij":
    raise ValueError("Valid values for indexing are 'xy' and 'ij', got {}"
                     .format(indexing))

  # Output shape: the length of each (1-D) input, or all 1s when sparse.
  shape = []
  for i, a in enumerate(args):
    args[i] = a = asarray(a)
    if len(a.shape) != 1:
      msg = "Arguments to jax.numpy.meshgrid must be 1D, got shape {}"
      raise ValueError(msg.format(a.shape))
    shape.append(1 if sparse else a.shape[0])

  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    s = shape
    if sparse:
      # Each sparse output is full-length only along its own axis.
      s = list(s)
      s[i] = a.shape[0]
    output.append(lax.broadcast_in_dim(a, s, (i,)))

  if indexing == "xy" and len(args) >= 2:
    output[0], output[1] = output[1], output[0]

  return output


@_wraps(np.i0)
def i0(x):
  # i0 is even, so evaluate at |x|; exp(|x|) * i0e(|x|) undoes the
  # exponential scaling of lax.bessel_i0e.
  x = lax.abs(*_promote_args_inexact("i0", x))
  return lax.mul(lax.exp(x), lax.bessel_i0e(x))


@_wraps(np.ix_)
def ix_(*args):
  # The i-th output has shape (1, ..., len(args[i]), ..., 1).
  n = len(args)
  output = []
  for i, a in enumerate(args):
    a = asarray(a)
    if len(a.shape) != 1:
      msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
      raise ValueError(msg.format(a.shape))
    if _dtype(a) == bool_:
      raise NotImplementedError(
        "Boolean arguments to jax.numpy.ix_ are not implemented")
    shape = [1] * n
    shape[i] = a.shape[0]
    if a.size == 0:
      # Numpy uses an integer index type for empty arrays.
      output.append(lax.full(shape, np.zeros((), np.intp)))
    else:
      output.append(lax.broadcast_in_dim(a, shape, (i,)))
  return tuple(output)


@_wraps(np.indices)
def indices(dimensions, dtype=int32, sparse=False):
  # `dimensions` must be concrete Python ints. Dense mode stacks the index
  # grids along a new leading axis; sparse mode returns a tuple of arrays
  # that are full-length only along their own axis.
  dimensions = tuple(
      core.concrete_or_error(int, d, "dimensions argument of jnp.indices")
      for d in dimensions)
  N = len(dimensions)
  output = []
  s = dimensions
  for i, dim in enumerate(dimensions):
    idx = lax.iota(dtype, dim)
    if sparse:
      s = (1,)*i + (dim,) + (1,)*(N - i - 1)
    output.append(lax.broadcast_in_dim(idx, s, (i,)))
  if sparse:
    return tuple(output)
  return stack(output, 0) if output else array([], dtype=dtype)


_TOTAL_REPEAT_LENGTH_DOC = """\
Jax adds the optional `total_repeat_length` parameter which specifies the total
number of repeat, and defaults to sum(repeats). It must be specified for repeat
to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length` the remaining values will be discarded. In the case of
`sum(repeats)` being smaller than the specified target length, the final value
will be repeated.
"""


@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
  _check_arraylike("repeat", a)

  if axis is None:
    a = ravel(a)
    axis = 0

  # Without total_repeat_length the output size depends on the values of
  # `repeats`, so `repeats` must be concrete and the size is computed on the
  # host with NumPy.
  if total_repeat_length is None:
    repeats = core.concrete_or_error(np.array, repeats,
      "When jit-compiling jnp.repeat, the total number of repeats must be static. "
      "To fix this, either specify a static value for `repeats`, or pass a static "
      "value to `total_repeat_length`.")

    # Fast path for when repeats is a scalar.
    if np.ndim(repeats) == 0 and ndim(a) != 0:
      # Insert an auxiliary axis right after `axis`, tile along it, then
      # merge it back into `axis`.
      input_shape = a.shape
      aux_axis = axis if axis < 0 else axis + 1
      a = expand_dims(a, aux_axis)
      reps = [1] * len(a.shape)
      reps[aux_axis] = repeats
      a = tile(a, reps)
      result_shape = list(input_shape)
      result_shape[axis] *= repeats
      return reshape(a, result_shape)

    repeats = np.ravel(repeats)
    if ndim(a) != 0:
      repeats = np.broadcast_to(repeats, [a.shape[axis]])
    total_repeat_length = np.sum(repeats)
  else:
    repeats = ravel(repeats)
    if ndim(a) != 0:
      repeats = broadcast_to(repeats, [a.shape[axis]])

  # Special case when a is a scalar.
  if ndim(a) == 0:
    if repeats.shape == (1,):
      return full([total_repeat_length], a)
    else:
      raise ValueError('`repeat` with a scalar parameter `a` is only '
                       'implemented for scalar values of the parameter `repeats`.')

  # Special case if total_repeat_length is zero.
  if total_repeat_length == 0:
    result_shape = list(a.shape)
    result_shape[axis] = 0
    return reshape(array([], dtype=a.dtype), result_shape)

  # If repeats is on a zero sized axis, then return the array.
  if a.shape[axis] == 0:
    return a

  # This implementation of repeat avoids materializing a large intermediate
  # tensor: it builds gather indices for `take` instead.

  # Modify repeats from e.g. [1,2,0,5] -> [0,1,2,0] for exclusive repeat.
  exclusive_repeats = roll(repeats, shift=1).at[0].set(0)
  # Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3]
  scatter_indices = cumsum(exclusive_repeats)
  # Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
  block_split_indicators = ops.index_add(
      x=zeros([total_repeat_length], dtype=int32),
      idx=scatter_indices,
      y=1)
  # Cumsum again to get gather indices for repeat, e.g. [0,1,1,3,3,3,3,3]
  gather_indices = cumsum(block_split_indicators) - 1
  return take(a, gather_indices, axis=axis)


@_wraps(np.tri)
def tri(N, M=None, k=0, dtype=None):
  # NOTE(review): the default dtype here is float32, whereas eye/zeros/ones
  # default to float_; confirm this difference is intentional.
  lax._check_user_dtype_supported(dtype, "tri")
  M = M if M is not None else N
  dtype = dtype or float32
  return lax._tri(dtype, (N, M), k)


@_wraps(np.tril)
def tril(m, k=0):
  # Keeps entries on and below diagonal k of the last two axes; zeros the rest.
  _check_arraylike("tril", m)
  m_shape = shape(m)
  if len(m_shape) < 2:
    raise ValueError("Argument to jax.numpy.tril must be at least 2D")
  mask = tri(*m_shape[-2:], k=k, dtype=bool)
  return lax.select(lax.broadcast(mask, m_shape[:-2]), m, zeros_like(m))


@_wraps(np.triu, update_doc=False)
def triu(m, k=0):
  # The mask tri(k - 1) marks entries strictly below diagonal k; those are
  # zeroed and entries on/above diagonal k are kept.
  _check_arraylike("triu", m)
  m_shape = shape(m)
  if len(m_shape) < 2:
    raise ValueError("Argument to jax.numpy.triu must be at least 2D")
  mask = tri(*m_shape[-2:], k=k - 1, dtype=bool)
  return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m)


@_wraps(np.trace)
def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
  _check_arraylike("trace", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
  lax._check_user_dtype_supported(dtype, "trace")

  axis1 = _canonicalize_axis(axis1, ndim(a))
  axis2 = _canonicalize_axis(axis2, ndim(a))

  a_shape = shape(a)
  if dtype is None:
    # Like NumPy, accumulate narrow integer inputs in the default int type.
    dtype = _dtype(a)
    if issubdtype(dtype, integer):
      default_int = dtypes.canonicalize_dtype(np.int_)
      if iinfo(dtype).bits < iinfo(default_int).bits:
        dtype = default_int

  # Move axis1 and axis2 (in that order) to the end.
  perm = [i for i in range(len(a_shape)) if i != axis1 and i != axis2]
  perm = perm + [axis1, axis2]
  a = lax.transpose(a, perm)

  # Mask out the diagonal and reduce.
  a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
            a, zeros_like(a))
  return sum(a, axis=(-2, -1), dtype=dtype)


def _wrap_indices_function(f):
  """Adapts a NumPy function returning a tuple of index arrays so that it
  returns a tuple of JAX arrays instead."""
  @_wraps(f, update_doc=False)
  def wrapper(*args, **kwargs):
    return tuple(asarray(x) for x in f(*args, **kwargs))
  return wrapper

# These run the NumPy implementation on the host and convert the results.
tril_indices = _wrap_indices_function(np.tril_indices)
triu_indices = _wrap_indices_function(np.triu_indices)
mask_indices = _wrap_indices_function(np.mask_indices)


@_wraps(np.triu_indices_from)
def triu_indices_from(arr, k=0):
  # Uses the sizes of the last two axes of `arr`.
  return triu_indices(arr.shape[-2], k=k, m=arr.shape[-1])


@_wraps(np.tril_indices_from)
def tril_indices_from(arr, k=0):
  # Uses the sizes of the last two axes of `arr`.
  return tril_indices(arr.shape[-2], k=k, m=arr.shape[-1])


@_wraps(np.diag_indices)
def diag_indices(n, ndim=2):
  # Returns the same iota(n) index array repeated `ndim` times.
  if n < 0:
    raise ValueError("n argument to diag_indices must be nonnegative, got {}"
                     .format(n))
  if ndim < 0:
    raise ValueError("ndim argument to diag_indices must be nonnegative, got {}"
                     .format(ndim))
  return (lax.iota(int_, n),) * ndim

@_wraps(np.diag_indices_from)
def diag_indices_from(arr):
  # Requires an array of rank >= 2 whose dimensions all have the same size.
  _check_arraylike("diag_indices_from", arr)
  if not arr.ndim >= 2:
    raise ValueError("input array must be at least 2-d")

  if len(set(arr.shape)) != 1:
    raise ValueError("All dimensions of input must be of equal length")

  return diag_indices(arr.shape[0], ndim=arr.ndim)

@_wraps(np.diagonal)
def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
  _check_arraylike("diagonal", a)
  a_shape = shape(a)
  a_ndims = len(a_shape)

  # Move the two dimensions to the end.
  axis1 = _canonicalize_axis(axis1, a_ndims)
  axis2 = _canonicalize_axis(axis2, a_ndims)
  perm = [i for i in range(a_ndims) if i != axis1 and i != axis2]
  perm = perm + [axis1, axis2]
  a = lax.transpose(a, perm)

  # Mask out the diagonal and reduce over one of the axes
  a = where(eye(a_shape[axis1], a_shape[axis2], k=offset, dtype=bool),
            a, zeros_like(a))
  reduce_axis = -2 if offset < 0 else -1
  d = sum(a, axis=reduce_axis, dtype=_dtype(a))

  # Slice out the correct diagonal size.
  diag_size = _max(0, _min(a_shape[axis1] + _min(offset, 0),
                           a_shape[axis2] - _max(offset, 0)))
  return lax.slice_in_dim(d, 0, diag_size, axis=-1)


@_wraps(np.diag)
def diag(v, k=0):
  # 1-D input: build a square matrix with `v` on diagonal k (padding `v` to
  # the matrix size first). 2-D input: extract diagonal k.
  _check_arraylike("diag", v)
  v_shape = shape(v)
  if len(v_shape) == 1:
    zero = lambda x: lax.full_like(x, shape=(), fill_value=0)
    n = v_shape[0] + _abs(k)
    v = lax.pad(v, zero(v), ((_max(0, k), _max(0, -k), 0),))
    return where(eye(n, k=k, dtype=bool), v, zeros_like(v))
  elif len(v_shape) == 2:
    return diagonal(v, offset=k)
  else:
    raise ValueError("diag input must be 1d or 2d")

_SCALAR_VALUE_DOC="""\
This differs from np.diagflat for some scalar values of v,
jax always returns a two-dimensional array, whereas numpy may
return a scalar depending on the type of v.
"""

@_wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC)
def diagflat(v, k=0):
  # Writes the flattened `v` into a flat buffer at the row-major positions of
  # diagonal k of an (n + |k|) x (n + |k|) matrix, then reshapes.
  _check_arraylike("diagflat", v)
  v = ravel(v)
  v_length = len(v)
  adj_length = v_length + _abs(k)
  res = zeros(adj_length*adj_length, dtype=v.dtype)
  i = arange(0, adj_length-_abs(k))
  if (k >= 0):
    fi = i+k+i*adj_length  # row i, column i + k
  else:
    fi = i+(i-k)*adj_length  # row i - k, column i
  res = ops.index_update(res, ops.index[fi], v)
  res = res.reshape(adj_length,adj_length)
  return res


@_wraps(np.polyval)
def polyval(p, x):
  # Horner's method; p[0] is the highest-degree coefficient.
  if isinstance(p, np.poly1d):
    p = np.asarray(p)
  if isinstance(x, np.poly1d):
    y = 0
  else:
    y = zeros_like(x)
  for i in range(len(p)):
    y = y * x + p[i]
  return y

@_wraps(np.polyadd)
def polyadd(a1, a2):
  # Coefficients are highest-degree first, so the shorter array is added into
  # the trailing (lowest-degree) end of the longer one.
  a1 = asarray(a1)
  a2 = asarray(a2)

  if a2.shape[0] <= a1.shape[0]:
    return a1.at[-a2.shape[0]:].add(a2)
  else:
    return a2.at[-a1.shape[0]:].add(a1)


@_wraps(np.polyder)
def polyder(p, m=1):
  p = asarray(p)
  if m < 0:
    raise ValueError("Order of derivative must be positive")
  if m == 0:
    return p
  if m % 1:
    raise ValueError("m must be an integer")
  # For each surviving term of degree d, coeff is d * (d-1) * ... * (d-m+1).
  coeff = (arange(len(p), m, -1) - 1 - arange(m)[:, newaxis]).prod(0)
  # Dropping the last m coefficients removes the terms that differentiate to 0.
  return p[:-m] * coeff

@_wraps(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
  # `filt` must be concrete: the output length depends on its values.
  filt = core.concrete_or_error(asarray, filt,
    "Error arose in the `filt` argument of trim_zeros()")
  nz = asarray(filt) == 0
  if all(nz):
    return empty(0, _dtype(filt))
  # argmin of the zero-mask finds the first nonzero from each requested end.
  start = argmin(nz) if 'f' in trim.lower() else 0
  end = argmin(nz[::-1]) if 'b' in trim.lower() else 0
  return filt[start:len(filt) - end]

_LEADING_ZEROS_DOC="""\
Setting trim_leading_zeros=True makes the output match that of numpy.
But prevents the function from being able to be used in compiled code.
"""

@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1, a2, *, trim_leading_zeros=False):
  # Polynomial product is the full convolution of the coefficient arrays.
  if isinstance(a1, np.poly1d):
    a1 = asarray(a1)
  if isinstance(a2, np.poly1d):
    a2 = asarray(a2)
  if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
    a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
  # An all-zero input trims to empty; treat it as the zero polynomial.
  if len(a1) == 0:
    a1 = asarray([0.])
  if len(a2) == 0:
    a2 = asarray([0.])
  val = convolve(a1, a2, mode='full')
  return val

@_wraps(np.polysub)
def polysub(a1, a2):
  return polyadd(asarray(a1), -asarray(a2))


@_wraps(np.append)
def append(arr, values, axis: Optional[int] = None):
  # With axis=None both inputs are flattened before concatenation.
  if axis is None:
    return concatenate([ravel(arr), ravel(values)], 0)
  else:
    return concatenate([arr, values], axis=axis)


@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
  # Wraps func1d in nested jax.vmap calls: first over the axes after `axis`
  # (their batch dims are appended at the end), then over the leading axes.
  num_dims = ndim(arr)
  axis = _canonicalize_axis(axis, num_dims)
  func = lambda arr: func1d(arr, *args, **kwargs)
  for i in range(1, num_dims - axis):
    func = jax.vmap(func, in_axes=i, out_axes=-1)
  for i in range(axis):
    func = jax.vmap(func, in_axes=0, out_axes=0)
  return func(arr)


@_wraps(np.apply_over_axes)
def apply_over_axes(func, a, axes):
  # func must preserve rank or drop exactly the reduced axis; in the latter
  # case the axis is restored with expand_dims.
  for axis in axes:
    b = func(a, axis=axis)
    if b.ndim == a.ndim:
      a = b
    elif b.ndim == a.ndim - 1:
      a = expand_dims(b, axis)
    else:
      raise ValueError("function is not returning an array of the correct shape")
  return a


### Tensor contraction operations


@_wraps(np.dot, lax_description=_PRECISION_DOC)
def dot(a, b, *, precision=None):  # pylint: disable=missing-docstring
  _check_arraylike("dot", a, b)
  a, b = _promote_dtypes(a, b)
  a_ndim, b_ndim = ndim(a), ndim(b)
  # Scalar operands: plain elementwise multiply.
  if a_ndim == 0 or b_ndim == 0:
    return lax.mul(a, b)
  if _max(a_ndim, b_ndim) <= 2:
    return lax.dot(a, b, precision=precision)

  # N-d case: contract the last axis of `a` with the second-to-last axis of
  # `b` (or its only axis when `b` is 1-D), with no batch dimensions.
  if b_ndim == 1:
    contract_dims = ((a_ndim - 1,), (0,))
  else:
    contract_dims = ((a_ndim - 1,), (b_ndim - 2,))
  batch_dims = ((), ())
  return lax.dot_general(a, b, (contract_dims, batch_dims), precision)


@_wraps(np.matmul, lax_description=_PRECISION_DOC)
def matmul(a, b, *, precision=None):  # pylint: disable=missing-docstring
  _check_arraylike("matmul", a, b)
  for i, x in enumerate((a, b)):
    if ndim(x) < 1:
      msg = (f"matmul input operand {i} must have ndim at least 1, "
             f"but it has ndim {ndim(x)}")
      raise ValueError(msg)

  a, b = _promote_dtypes(a, b)

  # Left-pad the batch dimensions with None so both operands have the same
  # number of (possibly missing) batch dimensions.
  a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
  a_batch_dims = shape(a)[:-2] if a_is_mat else ()
  b_batch_dims = shape(b)[:-2] if b_is_mat else ()
  num_batch_dims = _max(len(a_batch_dims), len(b_batch_dims))
  a_batch_dims = (None,) * (num_batch_dims - len(a_batch_dims)) + a_batch_dims
  b_batch_dims = (None,) * (num_batch_dims - len(b_batch_dims)) + b_batch_dims

  # Dimensions to squeeze from the inputs.
  a_squeeze = []
  b_squeeze = []

  # Positions of batch dimensions in squeezed inputs.
  a_batch = []
  b_batch = []

  # Desired index in final output of each kind of dimension, in the order that
  # lax.dot_general will emit them.
  idx_batch = []
  idx_a_other = []  # other = non-batch, non-contracting.
  idx_b_other = []
  for i, (ba, bb) in enumerate(zip(a_batch_dims, b_batch_dims)):
    if ba is None:
      idx_b_other.append(i)
    elif bb is None:
      idx_a_other.append(i)
    elif ba == 1:
      # Size-1 dimension broadcasts against the other operand: squeeze it
      # and treat the other operand's dimension as a free dimension.
      idx_b_other.append(i)
      a_squeeze.append(len(idx_batch) + len(idx_a_other) + len(a_squeeze))
    elif bb == 1:
      idx_a_other.append(i)
      b_squeeze.append(len(idx_batch) + len(idx_b_other) + len(b_squeeze))
    elif ba == bb:
      a_batch.append(len(idx_batch) + len(idx_a_other))
      b_batch.append(len(idx_batch) + len(idx_b_other))
      idx_batch.append(i)
    else:
      raise ValueError("Incompatible shapes for matmul arguments: {} and {}"
                       .format(shape(a), shape(b)))

  if a_is_mat: idx_a_other.append(num_batch_dims)
  if b_is_mat: idx_b_other.append(num_batch_dims + a_is_mat)
  # Permutation from dot_general's output order to the desired output order.
  perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other]))

  a = lax.squeeze(a, tuple(a_squeeze))
  b = lax.squeeze(b, tuple(b_squeeze))
  out = lax.dot_general(
    a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
    precision=precision)
  return lax.transpose(out, perm)


@_wraps(np.vdot, lax_description=_PRECISION_DOC)
def vdot(a, b, *, precision=None):
  # Conjugates `a` when complex, then takes the dot of the flattened inputs.
  _check_arraylike("vdot", a, b)
  if issubdtype(_dtype(a), complexfloating):
    a = conj(a)
  return dot(a.ravel(), b.ravel(), precision=precision)


@_wraps(np.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, *, precision=None):
  # `axes` may be an int N (last N axes of a with first N axes of b), a pair
  # of ints, or a pair of equal-length lists/tuples of ints.
  _check_arraylike("tensordot", a, b)
  a_ndim = ndim(a)
  b_ndim = ndim(b)

  a, b = _promote_dtypes(a, b)
  if type(axes) is int:
    if axes > _min(a_ndim, b_ndim):
      msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
      raise TypeError(msg.format(axes, a.shape, b.shape))
    contracting_dims = tuple(range(a_ndim - axes, a_ndim)), tuple(range(axes))
  elif type(axes) in (list, tuple) and len(axes) == 2:
    ax1, ax2 = axes
    if type(ax1) == type(ax2) == int:
      contracting_dims = ((_canonicalize_axis(ax1, a_ndim),),
                          (_canonicalize_axis(ax2, b_ndim),))
    elif type(ax1) in (list, tuple) and type(ax2) in (list, tuple):
      if len(ax1) != len(ax2):
        msg = "tensordot requires axes lists to have equal length, got {} and {}."
        raise TypeError(msg.format(ax1, ax2))
      contracting_dims = (tuple(_canonicalize_axis(i, a_ndim) for i in ax1),
                          tuple(_canonicalize_axis(i, b_ndim) for i in ax2))
    else:
      msg = ("tensordot requires both axes lists to be either ints, tuples or "
             "lists, got {} and {}")
      raise TypeError(msg.format(ax1, ax2))
  else:
    msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
           "of lists/tuples of ints.")
    raise TypeError(msg)
  return lax.dot_general(a, b, (contracting_dims, ((), ())),
                         precision=precision)


@_wraps(np.einsum, lax_description=_PRECISION_DOC)
def einsum(*operands, out=None, optimize='greedy', precision=None,
           _use_xeinsum=False):
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")

  # Subscript strings containing '{' with exactly two operands (or an
  # explicit _use_xeinsum) are routed to lax.xeinsum.
  if (_use_xeinsum or isinstance(operands[0], str) and '{' in operands[0] and
      len(operands[1:]) == 2):
    return lax.xeinsum(*operands)

  optimize = 'greedy' if optimize is True else optimize
  # using einsum_call=True here is an internal api for opt_einsum
  operands, contractions = opt_einsum.contract_path(
    *operands, einsum_call=True, use_blas=True, optimize=optimize)
  # Keep only (operand indices, contracted names, einsum string) from each
  # contraction; the frozenset makes the tuple hashable for jit's static args.
  contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
  return _einsum(operands, contractions, precision)

@_wraps(np.einsum_path)
def einsum_path(subscripts, *operands, optimize='greedy'):
  # Delegates to opt_einsum's public contract_path API (no einsum_call), so
  # the result is opt_einsum's (path, PathInfo) pair.
  return opt_einsum.contract_path(subscripts, *operands, optimize=optimize)

def _removechars(s, chars):
  """Returns `s` with every character that appears in `chars` removed."""
  return s.translate(str.maketrans(dict.fromkeys(chars)))

@partial(jit, static_argnums=(1, 2))
def _einsum(operands: Sequence,
            contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
            precision):
  """Evaluates an einsum contraction path with lax primitives.

  Each entry of `contractions` is (operand indices, contracted index names,
  einsum string). The indexed operands are popped from the working list,
  combined, and the result is appended; the last remaining operand is
  returned.
  """
  operands = list(_promote_dtypes(*operands))
  # Local reduce-sum (shadows the module-level `sum`); booleans are reduced
  # with bitwise_or instead of add.
  def sum(x, axes):
    return lax.reduce(x, np.array(0, x.dtype),
                      lax.add if x.dtype != bool_ else lax.bitwise_or, axes)

  def sum_uniques(operand, names, uniques):
    if uniques:
      axes = [names.index(name) for name in uniques]
      operand = sum(operand, axes)
      names = _removechars(names, uniques)
    return operand, names

  def sum_repeats(operand, names, counts, keep_names):
    # Collapse each repeated index by multiplying with a delta mask over its
    # axes and summing; one copy is kept if the name is still needed later.
    for name, count in counts.items():
      if count > 1:
        axes = [i for i, n in enumerate(names) if n == name]
        eye = lax._delta(operand.dtype, operand.shape, axes)
        if name not in keep_names:
          operand = sum(operand * eye, axes)
          names = names.replace(name, '')
        else:
          operand = sum(operand * eye, axes[:-1])
          names = names.replace(name, '', count - 1)
    return operand, names

  def filter_singleton_dims(operand, names, other_shape, other_names):
    # Drop size-1 axes of `operand` whose counterpart in the other operand
    # exists and has size != 1, so they broadcast instead of mismatching.
    s = shape(operand)
    new_shape = []
    new_names = []
    for i, d in enumerate(names):
      other_i = other_names.find(d)
      if s[i] != 1 or other_i == -1 or other_shape[other_i] == 1:
        new_shape.append(s[i])
        new_names.append(d)
    return reshape(operand, tuple(new_shape)), "".join(new_names)

  for operand_indices, contracted_names_set, einstr in contractions:
    contracted_names = sorted(contracted_names_set)
    input_str, result_names = einstr.split('->')
    input_names = input_str.split(',')

    # switch on the number of operands to be processed in this loop iteration.
    # every case here sets 'operand' and 'names'.
    if len(operand_indices) == 1:
      operand = operands.pop(operand_indices[0])
      names, = input_names
      counts = collections.Counter(names)

      # sum out unique contracted indices with a single reduce-sum
      uniques = [name for name in contracted_names if counts[name] == 1]
      operand, names = sum_uniques(operand, names, uniques)

      # for every repeated index, do a contraction against an identity matrix
      operand, names = sum_repeats(operand, names, counts, result_names)

    elif len(operand_indices) == 2:
      lhs, rhs = map(operands.pop, operand_indices)
      lhs_names, rhs_names = input_names

      # handle cases where one side of a contracting or batch dimension is 1
      # but its counterpart is not.
      lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
                                             rhs_names)
      rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
                                             lhs_names)

      lhs_counts = collections.Counter(lhs_names)
      rhs_counts = collections.Counter(rhs_names)

      # sum out unique contracted indices in lhs and rhs
      lhs_uniques = [name for name in contracted_names
                     if lhs_counts[name] == 1 and rhs_counts[name] == 0]
      lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)

      rhs_uniques = [name for name in contracted_names
                     if rhs_counts[name] == 1 and lhs_counts[name] == 0]
      rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)

      # for every repeated index, contract against an identity matrix
      lhs, lhs_names = sum_repeats(lhs, lhs_names, lhs_counts,
                                   result_names + rhs_names)
      rhs, rhs_names = sum_repeats(rhs, rhs_names, rhs_counts,
                                   result_names + lhs_names)

      # Names shared by both sides and kept in the result are batch dims.
      lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
      contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
      lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
      batch_names = [x for x in result_names if x in lhs_and_rhs_names]

      lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
                                    for n in batch_names)

      # NOTE(mattjj): this can fail non-deterministically in python3, maybe
      # due to opt_einsum
      assert _all(
        name in lhs_names and name in rhs_names and
        lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
        for name in contracted_names)

      # contract using lax.dot_general
      batch_names_str = ''.join(batch_names)
      lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
                                  for n in contracted_names)
      dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
      operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
      # dot_general emits batch dims, then lhs free dims, then rhs free dims.
      deleted_names = batch_names_str + ''.join(contracted_names)
      names = (batch_names_str + _removechars(lhs_names, deleted_names)
               + _removechars(rhs_names, deleted_names))
    else:
      raise NotImplementedError  # if this is actually reachable, open an issue!

    # the resulting 'operand' with axis labels 'names' should be a permutation
    # of the desired result
    assert len(names) == len(result_names) == len(set(names))
    assert set(names) == set(result_names)
    if names != result_names:
      perm = tuple([names.index(name) for name in result_names])
      operand = lax.transpose(operand, perm)
    operands.append(operand)  # used in next iteration

  return operands[0]


def _movechars(s, src, dst):
  """Helper for einsum string munging, like moveaxis on identifier strings."""
  chars = [c for i, c in enumerate(s) if i not in src]
  for i, j in sorted(zip(dst, src)):
    chars.insert(i, s[j])
  return ''.join(chars)


@_wraps(np.inner, lax_description=_PRECISION_DOC)
def inner(a, b, *, precision=None):
  # Scalars multiply elementwise; otherwise contract the last axes of both.
  if ndim(a) == 0 or ndim(b) == 0:
    return a * b
  return tensordot(a, b, (-1, -1), precision=precision)


@_wraps(np.outer)
def outer(a, b, out=None):
  if out is not
None: 3776 raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") 3777 a, b = _promote_dtypes(a, b) 3778 return ravel(a)[:, None] * ravel(b)[None, :] 3779 3780@partial(jit, static_argnums=(2, 3, 4)) 3781def _cross(a, b, axisa, axisb, axisc): 3782 a = moveaxis(a, axisa, -1) 3783 b = moveaxis(b, axisb, -1) 3784 3785 if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): 3786 raise ValueError("Dimension must be either 2 or 3 for cross product") 3787 3788 if a.shape[-1] == 2 and b.shape[-1] == 2: 3789 return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0] 3790 3791 a0 = a[..., 0] 3792 a1 = a[..., 1] 3793 a2 = a[..., 2] if a.shape[-1] == 3 else zeros_like(a0) 3794 b0 = b[..., 0] 3795 b1 = b[..., 1] 3796 b2 = b[..., 2] if b.shape[-1] == 3 else zeros_like(b0) 3797 c = array([a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0]) 3798 return moveaxis(c, 0, axisc) 3799 3800@_wraps(np.cross) 3801def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, 3802 axis: Optional[int] = None): 3803 if axis is not None: 3804 axisa = axis 3805 axisb = axis 3806 axisc = axis 3807 return _cross(a, b, axisa, axisb, axisc) 3808 3809@_wraps(np.kron) 3810def kron(a, b): 3811 a, b = _promote_dtypes(a, b) 3812 if ndim(a) < ndim(b): 3813 a = reshape(a, (1,) * (ndim(b) - ndim(a)) + shape(a)) 3814 elif ndim(b) < ndim(a): 3815 b = reshape(b, (1,) * (ndim(a) - ndim(b)) + shape(b)) 3816 a_reshaped = reshape(a, [i for d in shape(a) for i in (d, 1)]) 3817 b_reshaped = reshape(b, [i for d in shape(b) for i in (1, d)]) 3818 out_shape = tuple(np.multiply(shape(a), shape(b))) 3819 return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) 3820 3821 3822@_wraps(np.vander) 3823def vander(x, N=None, increasing=False): 3824 x = asarray(x) 3825 dtype = _dtype(x) 3826 if ndim(x) != 1: 3827 raise ValueError("x must be a one-dimensional array") 3828 x_shape = shape(x) 3829 N = N or x_shape[0] 3830 if N < 0: 3831 raise ValueError("N must be nonnegative") 3832 3833 iota = 
lax.iota(dtype, N) 3834 if not increasing: 3835 iota = lax.sub(lax._const(iota, N - 1), iota) 3836 3837 return power(x[..., None], iota) 3838 3839 3840### Misc 3841 3842 3843@_wraps(np.argwhere) 3844def argwhere(a): 3845 result = transpose(vstack(nonzero(a))) 3846 if ndim(a) == 0: 3847 return result[:0].reshape(result.shape[0], 0) 3848 return result.reshape(result.shape[0], ndim(a)) 3849 3850 3851@_wraps(np.argmax) 3852def argmax(a, axis: Optional[int] = None, out=None): 3853 _check_arraylike("argmax", a) 3854 if out is not None: 3855 raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.") 3856 if axis is None: 3857 a = ravel(a) 3858 axis = 0 3859 if a.shape[axis] == 0: 3860 raise ValueError("attempt to get argmax of an empty sequence") 3861 return lax.argmax(a, _canonicalize_axis(axis, a.ndim), int64) 3862 3863@_wraps(np.argmin) 3864def argmin(a, axis: Optional[int] = None, out=None): 3865 _check_arraylike("argmin", a) 3866 if out is not None: 3867 raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.") 3868 if axis is None: 3869 a = ravel(a) 3870 axis = 0 3871 if a.shape[axis] == 0: 3872 raise ValueError("attempt to get argmin of an empty sequence") 3873 return lax.argmin(a, _canonicalize_axis(axis, a.ndim), int64) 3874 3875 3876_NANARG_DOC = """\ 3877Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise 3878an error. 
3879""" 3880 3881@_wraps(np.nanargmax, lax_description=_NANARG_DOC.format("max")) 3882def nanargmax(a, axis: Optional[int] = None): 3883 _check_arraylike("nanargmax", a) 3884 if not issubdtype(_dtype(a), inexact): 3885 return argmax(a, axis=axis) 3886 nan_mask = isnan(a) 3887 a = where(nan_mask, -inf, a) 3888 res = argmax(a, axis=axis) 3889 return where(all(nan_mask, axis=axis), -1, res) 3890 3891@_wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min")) 3892def nanargmin(a, axis: Optional[int] = None): 3893 _check_arraylike("nanargmin", a) 3894 if not issubdtype(_dtype(a), inexact): 3895 return argmin(a, axis=axis) 3896 nan_mask = isnan(a) 3897 a = where(nan_mask, inf, a) 3898 res = argmin(a, axis=axis) 3899 return where(all(nan_mask, axis=axis), -1, res) 3900 3901 3902@_wraps(np.sort) 3903def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None): 3904 _check_arraylike("sort", a) 3905 if kind != 'quicksort': 3906 warnings.warn("'kind' argument to sort is ignored.") 3907 if order is not None: 3908 raise ValueError("'order' argument to sort is not supported.") 3909 3910 if axis is None: 3911 return lax.sort(a.ravel(), dimension=0) 3912 else: 3913 return lax.sort(a, dimension=_canonicalize_axis(axis, ndim(a))) 3914 3915@_wraps(np.sort_complex) 3916def sort_complex(a): 3917 _check_arraylike("sort_complex", a) 3918 a = lax.sort(a, dimension=0) 3919 return lax.convert_element_type(a, result_type(a, dtypes.canonicalize_dtype(complex_))) 3920 3921@_wraps(np.lexsort) 3922def lexsort(keys, axis=-1): 3923 keys = tuple(keys) 3924 if len(keys) == 0: 3925 raise TypeError("need sequence of keys with len > 0 in lexsort") 3926 if len({shape(key) for key in keys}) > 1: 3927 raise ValueError("all keys need to be the same shape") 3928 if ndim(keys[0]) == 0: 3929 return np.int64(0) 3930 axis = _canonicalize_axis(axis, ndim(keys[0])) 3931 iota = lax.broadcasted_iota(np.int64, shape(keys[0]), axis) 3932 return lax.sort((*keys[::-1], iota), dimension=axis, 
num_keys=len(keys))[-1] 3933 3934 3935@_wraps(np.argsort) 3936def argsort(a, axis: Optional[int] = -1, kind='quicksort', order=None): 3937 _check_arraylike("argsort", a) 3938 if kind != 'quicksort': 3939 warnings.warn("'kind' argument to argsort is ignored.") 3940 if order is not None: 3941 raise ValueError("'order' argument to argsort is not supported.") 3942 3943 if axis is None: 3944 return argsort(a.ravel(), 0) 3945 else: 3946 axis_num = _canonicalize_axis(axis, ndim(a)) 3947 iota = lax.broadcasted_iota(np.int64, shape(a), axis_num) 3948 _, perm = lax.sort_key_val(a, iota, dimension=axis_num) 3949 return perm 3950 3951 3952@_wraps(np.msort) 3953def msort(a): 3954 return sort(a, axis=0) 3955 3956 3957@partial(jit, static_argnums=(2,)) 3958def _roll(a, shift, axis): 3959 a = asarray(a) 3960 a_shape = shape(a) 3961 if axis is None: 3962 return lax.reshape(roll(ravel(a), shift, axis=0), a_shape) 3963 3964 a_ndim = len(a_shape) 3965 shift = asarray(shift) 3966 axis = np.asarray(axis) 3967 b_shape = lax.broadcast_shapes(shift.shape, axis.shape, (1,)) 3968 if len(b_shape) != 1: 3969 msg = "'shift' and 'axis' arguments to roll must be scalars or 1D arrays" 3970 raise ValueError(msg) 3971 3972 for x, i in zip(broadcast_to(shift, b_shape), 3973 np.broadcast_to(axis, b_shape)): 3974 i = _canonicalize_axis(i, a_ndim) 3975 x = remainder(x, (a_shape[i] or 1)) 3976 a = lax.concatenate((a, a), i) 3977 a = lax.dynamic_slice_in_dim(a, a_shape[i] - x, a_shape[i], axis=i) 3978 return a 3979 3980 3981@_wraps(np.roll) 3982def roll(a, shift, axis: Optional[Union[int, Sequence[int]]] = None): 3983 if isinstance(axis, list): 3984 axis = tuple(axis) 3985 return _roll(a, shift, axis) 3986 3987 3988@_wraps(np.rollaxis) 3989def rollaxis(a, axis: int, start=0): 3990 _check_arraylike("rollaxis", a) 3991 a_ndim = ndim(a) 3992 axis = _canonicalize_axis(axis, a_ndim) 3993 if not (-a_ndim <= start <= a_ndim): 3994 raise ValueError(f"start={start} must satisfy {-a_ndim}<=start<={a_ndim}") 3995 if 
start < 0: 3996 start += a_ndim 3997 if start > axis: 3998 start -= 1 3999 return moveaxis(a, axis, start) 4000 4001 4002@_wraps(np.packbits) 4003def packbits(a, axis: Optional[int] = None, bitorder='big'): 4004 a = asarray(a) 4005 if not (issubdtype(dtype(a), integer) or issubdtype(dtype(a), bool_)): 4006 raise TypeError('Expected an input array of integer or boolean data type') 4007 if bitorder not in ['little', 'big']: 4008 raise ValueError("'order' must be either 'little' or 'big'") 4009 a = (a > 0).astype('uint8') 4010 bits = arange(8, dtype='uint8') 4011 if bitorder == 'big': 4012 bits = bits[::-1] 4013 if axis is None: 4014 a = ravel(a) 4015 axis = 0 4016 a = swapaxes(a, axis, -1) 4017 4018 remainder = a.shape[-1] % 8 4019 if remainder: 4020 a = pad(a, (a.ndim - 1) * [(0, 0)] + [(0, 8 - remainder)]) 4021 4022 a = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8)) 4023 packed = (a << bits).sum(-1).astype('uint8') 4024 return swapaxes(packed, axis, -1) 4025 4026 4027@_wraps(np.unpackbits) 4028def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'): 4029 a = asarray(a) 4030 if dtype(a) != uint8: 4031 raise TypeError("Expected an input array of unsigned byte data type") 4032 if bitorder not in ['little', 'big']: 4033 raise ValueError("'order' must be either 'little' or 'big'") 4034 bits = asarray(1) << arange(8, dtype='uint8') 4035 if bitorder == 'big': 4036 bits = bits[::-1] 4037 if axis is None: 4038 a = a.ravel() 4039 axis = 0 4040 a = swapaxes(a, axis, -1) 4041 unpacked = ((a[..., None] & bits) > 0).astype('uint8') 4042 unpacked = unpacked.reshape(unpacked.shape[:-2] + (-1,))[..., :count] 4043 return swapaxes(unpacked, axis, -1) 4044 4045 4046@_wraps(np.take) 4047def take(a, indices, axis: Optional[int] = None, out=None, mode=None): 4048 if out is not None: 4049 raise NotImplementedError("The 'out' argument to jnp.take is not supported.") 4050 4051 a = asarray(a) 4052 indices = asarray(indices) 4053 4054 if axis is None: 4055 a = ravel(a) 4056 
axis_idx = 0 4057 else: 4058 axis_idx = _canonicalize_axis(axis, ndim(a)) 4059 4060 if mode == "raise": 4061 # TODO(phawkins): we have no way to report out of bounds errors yet. 4062 raise NotImplementedError("The 'raise' mode to jnp.take is not supported.") 4063 elif mode == "wrap": 4064 indices = mod(indices, _constant_like(indices, a.shape[axis_idx])) 4065 elif mode != "clip" and mode is not None: 4066 raise ValueError("Invalid mode '{}' for np.take".format(mode)) 4067 4068 index_dims = len(shape(indices)) 4069 slice_sizes = list(shape(a)) 4070 slice_sizes[axis_idx] = _min(indices.size, 1) 4071 dnums = lax.GatherDimensionNumbers( 4072 offset_dims=tuple( 4073 list(range(axis_idx)) + 4074 list(range(axis_idx + index_dims, len(a.shape) + index_dims - 1))), 4075 collapsed_slice_dims=(axis_idx,), 4076 start_index_map=(axis_idx,)) 4077 return lax.gather(a, indices[..., None], dimension_numbers=dnums, 4078 slice_sizes=tuple(slice_sizes)) 4079 4080 4081def _normalize_index(index, axis_size): 4082 """Normalizes an index value in the range [-N, N) to the range [0, N).""" 4083 if type(axis_size) is Poly: 4084 return index + axis_size if index < 0 else index 4085 4086 return lax.select( 4087 lax.lt(index, _constant_like(index, 0)), 4088 lax.add(index, _constant_like(index, axis_size)), 4089 index) 4090 4091@partial(jit, static_argnums=(2,)) 4092def _take_along_axis(arr, indices, axis): 4093 if axis is None: 4094 if ndim(indices) != 1: 4095 msg = "take_along_axis indices must be 1D if axis=None, got shape {}" 4096 raise ValueError(msg.format(indices.shape)) 4097 return take_along_axis(arr.ravel(), indices, 0) 4098 rank = ndim(arr) 4099 if rank != ndim(indices): 4100 msg = "indices and arr must have the same number of dimensions; {} vs. 
{}" 4101 raise ValueError(msg.format(ndim(indices), ndim(arr))) 4102 axis = _canonicalize_axis(axis, rank) 4103 4104 def replace(tup, val): 4105 lst = list(tup) 4106 lst[axis] = val 4107 return tuple(lst) 4108 4109 use_64bit_index = _any([type(d) is Poly or d >= (1 << 31) for d in arr.shape]) 4110 index_dtype = int64 if use_64bit_index else int32 4111 indices = lax.convert_element_type(indices, index_dtype) 4112 4113 bcast_shape = lax.broadcast_shapes(replace(arr.shape, 1), replace(indices.shape, 1)) 4114 indices = broadcast_to(indices, replace(bcast_shape, indices.shape[axis])) 4115 arr = broadcast_to(arr, replace(bcast_shape, arr.shape[axis])) 4116 4117 axis_size = arr.shape[axis] 4118 arr_shape = replace(arr.shape, 1) 4119 idx_shape = indices.shape 4120 out_shape = lax.broadcast_shapes(idx_shape, arr_shape) 4121 4122 index_dims = [i for i, idx in enumerate(idx_shape) if i == axis or idx != 1] 4123 4124 gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,) 4125 gather_indices = [] 4126 slice_sizes = [] 4127 offset_dims = [] 4128 start_index_map = [] 4129 collapsed_slice_dims = [] 4130 j = 0 4131 for i in range(rank): 4132 if i == axis: 4133 indices = _normalize_index(indices, axis_size) 4134 gather_indices.append(lax.reshape(indices, gather_index_shape)) 4135 slice_sizes.append(1) 4136 start_index_map.append(i) 4137 collapsed_slice_dims.append(i) 4138 j += 1 4139 elif idx_shape[i] != 1: 4140 iota = lax.iota(_dtype(indices), out_shape[i]) 4141 if not config.omnistaging_enabled: 4142 iota = lax.tie_in(arr, iota) 4143 iota = lax.broadcast_in_dim(iota, gather_index_shape, (j,)) 4144 gather_indices.append(iota) 4145 slice_sizes.append(1) 4146 start_index_map.append(i) 4147 collapsed_slice_dims.append(i) 4148 j += 1 4149 else: 4150 # If idx_shape[i] == 1, we can just take the entirety of the arr's axis 4151 # and avoid forming an iota index. 
4152 offset_dims.append(i) 4153 slice_sizes.append(arr_shape[i]) 4154 4155 gather_indices = lax.concatenate(gather_indices, dimension=j) 4156 dnums = lax.GatherDimensionNumbers( 4157 offset_dims=tuple(offset_dims), 4158 collapsed_slice_dims=tuple(collapsed_slice_dims), 4159 start_index_map=tuple(start_index_map)) 4160 return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes)) 4161 4162 4163@_wraps(getattr(np, "take_along_axis", None), update_doc=False) 4164def take_along_axis(arr, indices, axis: Optional[int]): 4165 _check_arraylike("take_along_axis", arr) 4166 return _take_along_axis(arr, indices, axis) 4167 4168 4169### SetOps 4170 4171@partial(jit, static_argnums=1) 4172def _unique1d_sorted_mask(ar, optional_indices=False): 4173 """ 4174 Helper function for unique which is jit-able 4175 """ 4176 4177 ar = asarray(ar).flatten() 4178 4179 if optional_indices: 4180 perm = ar.argsort() 4181 aux = ar[perm] 4182 else: 4183 aux = ar.sort() 4184 4185 mask = empty(aux.shape, dtype=bool_) 4186 mask = ops.index_update(mask, ops.index[:1], True) 4187 mask = ops.index_update(mask, ops.index[1:], aux[1:] != aux[:-1]) 4188 4189 if optional_indices: 4190 return aux, mask, perm 4191 else: 4192 return aux, mask 4193 4194def _unique1d(ar, return_index=False, return_inverse=False, 4195 return_counts=False): 4196 """ 4197 Find the unique elements of an array, ignoring shape. 
4198 """ 4199 4200 optional_indices = return_index or return_inverse 4201 4202 if optional_indices: 4203 aux, mask, perm = _unique1d_sorted_mask(ar, optional_indices) 4204 else: 4205 aux, mask = _unique1d_sorted_mask(ar, optional_indices) 4206 4207 ret = (aux[mask],) 4208 if return_index: 4209 ret += (perm[mask],) 4210 if return_inverse: 4211 imask = cumsum(mask) - 1 4212 inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_)) 4213 inv_idx = ops.index_update(inv_idx, perm, imask) 4214 ret += (inv_idx,) 4215 if return_counts: 4216 idx = concatenate(nonzero(mask) + (array([mask.size]),)) 4217 ret += (diff(idx),) 4218 return ret 4219 4220@_wraps(np.unique) 4221def unique(ar, return_index=False, return_inverse=False, 4222 return_counts=False, axis: Optional[int] = None): 4223 ar = core.concrete_or_error(asarray, ar, "The error arose in jnp.unique()") 4224 4225 if iscomplexobj(ar): 4226 raise NotImplementedError( 4227 "np.unique is not implemented for complex valued arrays") 4228 4229 if axis is None: 4230 ret = _unique1d(ar, return_index, return_inverse, return_counts) 4231 if len(ret) == 1: 4232 return ret[0] 4233 else: 4234 return ret 4235 4236 raise NotImplementedError( 4237 "np.unique is not implemented for the axis argument") 4238 4239### Indexing 4240 4241def _rewriting_take(arr, idx): 4242 # Computes arr[idx]. 4243 # All supported cases of indexing can be implemented as an XLA gather, 4244 # followed by an optional reverse and broadcast_in_dim. 4245 arr = asarray(arr) 4246 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx) 4247 return _gather(arr, treedef, static_idx, dynamic_idx) 4248 4249# TODO(phawkins): re-enable jit after fixing excessive recompilation for 4250# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.). 
4251# @partial(jit, static_argnums=(1, 2)) 4252def _gather(arr, treedef, static_idx, dynamic_idx): 4253 idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx) 4254 indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update 4255 y = arr 4256 4257 # Avoid calling gather if the slice shape is empty, both as a fast path and to 4258 # handle cases like zeros(0)[array([], int32)]. 4259 if _prod(indexer.slice_shape) == 0: 4260 return zeros_like(y, shape=indexer.slice_shape) 4261 4262 # We avoid generating a gather when indexer.gather_indices.size is empty. 4263 if indexer.gather_indices.size: 4264 y = lax.gather(y, indexer.gather_indices, indexer.dnums, 4265 indexer.gather_slice_shape) 4266 4267 # Reverses axes with negative strides. 4268 if indexer.reversed_y_dims: 4269 y = lax.rev(y, indexer.reversed_y_dims) 4270 4271 # This adds np.newaxis/None dimensions. 4272 return expand_dims(y, indexer.newaxis_dims) 4273 4274_Indexer = collections.namedtuple("_Indexer", [ 4275 # The expected shape of the slice output. 4276 "slice_shape", 4277 4278 # The slice shape to pass to lax.gather(). 4279 "gather_slice_shape", 4280 4281 # The gather indices to use. 4282 "gather_indices", 4283 4284 # A GatherDimensionNumbers object describing the gather to perform. 4285 "dnums", 4286 4287 # Slice dimensions that have negative strides, and so must be reversed after 4288 # the gather. 4289 "reversed_y_dims", 4290 4291 # Keep track of any axes created by `newaxis`. These must be inserted for 4292 # gathers and eliminated for scatters. 4293 "newaxis_dims", 4294]) 4295 4296def _split_index_for_jit(idx): 4297 """Splits indices into necessarily-static and dynamic parts. 4298 4299 Used to pass indices into `jit`-ted function. 4300 """ 4301 # Convert list indices to tuples in cases (deprecated by NumPy.) 4302 idx = _eliminate_deprecated_list_indexing(idx) 4303 4304 # Expand any (concrete) boolean indices. 
We can then use advanced integer 4305 # indexing logic to handle them. 4306 idx = _expand_bool_indices(idx) 4307 4308 leaves, treedef = tree_flatten(idx) 4309 dynamic = [None] * len(leaves) 4310 static = [None] * len(leaves) 4311 for i, x in enumerate(leaves): 4312 if x is Ellipsis: 4313 static[i] = x 4314 elif isinstance(x, slice): 4315 # slice objects aren't hashable. 4316 static[i] = (x.start, x.stop, x.step) 4317 else: 4318 dynamic[i] = x 4319 return treedef, tuple(static), dynamic 4320 4321def _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx): 4322 """Recombines indices that were split by _split_index_for_jit.""" 4323 idx = [] 4324 for s, d in zip(static_idx, dynamic_idx): 4325 if d is not None: 4326 idx.append(d) 4327 elif isinstance(s, tuple): 4328 idx.append(slice(s[0], s[1], s[2])) 4329 else: 4330 idx.append(s) 4331 return treedef.unflatten(idx) 4332 4333def _int(aval): 4334 return not aval.shape and issubdtype(aval.dtype, integer) 4335 4336def _index_to_gather(x_shape, idx, normalize_indices=True): 4337 # Remove ellipses and add trailing slice(None)s. 4338 idx = _canonicalize_tuple_index(len(x_shape), idx) 4339 4340 # Check for advanced indexing: 4341 # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing 4342 4343 # Do the advanced indexing axes appear contiguously? If not, NumPy semantics 4344 # move the advanced axes to the front. 4345 advanced_axes_are_contiguous = False 4346 4347 advanced_indexes = None 4348 4349 # The positions of the advanced indexing axes in `idx`. 4350 idx_advanced_axes = [] 4351 4352 # The positions of the advanced indexes in x's shape. 4353 # collapsed, after None axes have been removed. See below. 
4354 x_advanced_axes = None 4355 4356 if _is_advanced_int_indexer(idx): 4357 idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None] 4358 advanced_pairs = ( 4359 (asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones) 4360 if isscalar(e) or isinstance(e, (Sequence, ndarray))) 4361 if normalize_indices: 4362 advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j) 4363 for e, i, j in advanced_pairs) 4364 advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs) 4365 advanced_axes_are_contiguous = np.all(np.diff(idx_advanced_axes) == 1) 4366 4367 x_axis = 0 # Current axis in x. 4368 y_axis = 0 # Current axis in y, before collapsing. See below. 4369 collapsed_y_axis = 0 # Current axis in y, after collapsing. 4370 4371 # Scatter dimension numbers. 4372 offset_dims = [] 4373 collapsed_slice_dims = [] 4374 start_index_map = [] 4375 4376 use_64bit_index = _any([type(d) is Poly or d >= (1 << 31) for d in x_shape]) 4377 index_dtype = int64 if use_64bit_index else int32 4378 gather_indices = np.zeros((0,), dtype=index_dtype) # use np to save a compilation 4379 4380 # We perform three transformations to y before the scatter op, in order: 4381 # First, y is broadcast to slice_shape. In general `y` only need broadcast to 4382 # the right shape. 4383 slice_shape = [] 4384 4385 # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None` 4386 # indices, which the scatter cannot remove itself. 4387 newaxis_dims = [] 4388 4389 # Finally, we reverse reversed_y_dims to handle slices with negative strides. 4390 reversed_y_dims = [] 4391 4392 gather_slice_shape = [] 4393 4394 for idx_pos, i in enumerate(idx): 4395 # Handle the advanced indices here if: 4396 # * the advanced indices were not contiguous and we are the start. 4397 # * we are at the position of the first advanced index. 
4398 if (advanced_indexes is not None and 4399 (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or 4400 not advanced_axes_are_contiguous and idx_pos == 0)): 4401 advanced_indexes = broadcast_arrays(*advanced_indexes) 4402 shape = advanced_indexes[0].shape 4403 ndim = len(shape) 4404 advanced_indexes = [ 4405 lax.convert_element_type(lax.reshape(a, shape + (1,)), index_dtype) 4406 for a in advanced_indexes] 4407 4408 # Broadcast gather_indices from [..., k] to [..., 1, 1, ..., 1, k]. 4409 gather_indices = lax.broadcast_in_dim( 4410 gather_indices, np.insert(gather_indices.shape, -1, shape), 4411 tuple(range(gather_indices.ndim - 1)) + (gather_indices.ndim + ndim - 1,)) 4412 gather_indices = concatenate([gather_indices] + advanced_indexes, -1) 4413 start_index_map.extend(x_advanced_axes) 4414 collapsed_slice_dims.extend(x_advanced_axes) 4415 slice_shape.extend(shape) 4416 y_axis += ndim 4417 collapsed_y_axis += ndim 4418 4419 # Per-index bookkeeping for advanced indexes. 4420 if idx_pos in idx_advanced_axes: 4421 x_axis += 1 4422 gather_slice_shape.append(1) 4423 continue 4424 4425 try: 4426 abstract_i = core.get_aval(i) 4427 except TypeError: 4428 abstract_i = None 4429 # Handle basic int indexes. 
4430 if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i): 4431 if x_shape[x_axis] == 0: 4432 # XLA gives error when indexing into an axis of size 0 4433 raise IndexError(f"index is out of bounds for axis {x_axis} with size 0") 4434 i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i 4435 if type(i) is Poly: 4436 # dummy index if i is polynomial, doesn't matter for shape inference 4437 # TODO(mattjj,j-towns,juliuskunze): revise this logic 4438 i = 0 4439 i = lax.convert_element_type(i, index_dtype) 4440 i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,)) 4441 gather_indices = concatenate((gather_indices, i), -1) 4442 collapsed_slice_dims.append(x_axis) 4443 gather_slice_shape.append(1) 4444 start_index_map.append(x_axis) 4445 x_axis += 1 4446 # Handle np.newaxis (None) 4447 elif i is None: 4448 slice_shape.append(1) 4449 newaxis_dims.append(y_axis) 4450 y_axis += 1 4451 # Handle slice(None) 4452 elif _is_slice_none(i): 4453 slice_shape.append(x_shape[x_axis]) 4454 gather_slice_shape.append(x_shape[x_axis]) 4455 offset_dims.append(collapsed_y_axis) 4456 collapsed_y_axis += 1 4457 y_axis += 1 4458 x_axis += 1 4459 # Handle slice index (only static, otherwise an error is raised) 4460 elif isinstance(i, slice): 4461 if not _all(elt is None or type(elt) is Poly 4462 or type(core.get_aval(elt)) is ConcreteArray 4463 for elt in (i.start, i.stop, i.step)): 4464 msg = ("Array slice indices must have static start/stop/step to be used " 4465 "with NumPy indexing syntax. 
To index a statically sized " 4466 "array at a dynamic position, try lax.dynamic_slice/" 4467 "dynamic_update_slice (JAX does not support dynamically sized " 4468 "arrays within JIT compiled functions).") 4469 raise IndexError(msg) 4470 start, limit, stride, needs_rev = _static_idx(i, x_shape[x_axis]) 4471 if needs_rev: 4472 reversed_y_dims.append(collapsed_y_axis) 4473 if stride == 1: 4474 i = lax.convert_element_type(start, index_dtype) 4475 i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,)) 4476 gather_indices = concatenate((gather_indices, i), -1) 4477 slice_shape.append(limit - start) 4478 gather_slice_shape.append(limit - start) 4479 offset_dims.append(collapsed_y_axis) 4480 start_index_map.append(x_axis) 4481 else: 4482 i = arange(start, limit, stride, dtype=index_dtype) 4483 size = i.shape[0] 4484 slice_shape.append(size) 4485 gather_slice_shape.append(1) 4486 gather_indices_shape = tuple(gather_indices.shape[:-1]) + (size,) 4487 i = lax.broadcast_in_dim( 4488 i, shape=gather_indices_shape + (1,), 4489 broadcast_dimensions=(len(gather_indices_shape) - 1,)) 4490 gather_indices = lax.broadcast_in_dim( 4491 gather_indices, 4492 shape=gather_indices_shape + (len(start_index_map),), 4493 broadcast_dimensions=( 4494 tuple(range(len(gather_indices_shape) - 1)) + 4495 (len(gather_indices_shape),))) 4496 gather_indices = concatenate( 4497 (gather_indices, i), len(gather_indices_shape)) 4498 start_index_map.append(x_axis) 4499 collapsed_slice_dims.append(x_axis) 4500 4501 collapsed_y_axis += 1 4502 y_axis += 1 4503 x_axis += 1 4504 else: 4505 if (abstract_i is not None and 4506 not (issubdtype(abstract_i.dtype, integer) or issubdtype(abstract_i.dtype, bool_))): 4507 msg = ("Indexer must have integer or boolean type, got indexer " 4508 "with type {} at position {}, indexer value {}") 4509 raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i)) 4510 4511 msg = "Indexing mode not yet supported. 
Open a feature request!\n{}" 4512 raise IndexError(msg.format(idx)) 4513 4514 dnums = lax.GatherDimensionNumbers( 4515 offset_dims = tuple(offset_dims), 4516 collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)), 4517 start_index_map = tuple(start_index_map) 4518 ) 4519 return _Indexer( 4520 slice_shape=slice_shape, 4521 newaxis_dims=tuple(newaxis_dims), 4522 gather_slice_shape=gather_slice_shape, 4523 reversed_y_dims=reversed_y_dims, 4524 dnums=dnums, 4525 gather_indices=gather_indices) 4526 4527def _should_unpack_list_index(x): 4528 """Helper for _eliminate_deprecated_list_indexing.""" 4529 return (isinstance(x, ndarray) and np.ndim(x) != 0 4530 or isinstance(x, (Sequence, slice)) 4531 or x is Ellipsis or x is None) 4532 4533def _eliminate_deprecated_list_indexing(idx): 4534 # "Basic slicing is initiated if the selection object is a non-array, 4535 # non-tuple sequence containing slice objects, [Ellipses, or newaxis 4536 # objects]". Detects this and raises a TypeError. 4537 if not isinstance(idx, tuple): 4538 if isinstance(idx, Sequence) and not isinstance(idx, ndarray): 4539 # As of numpy 1.16, some non-tuple sequences of indices result in a warning, while 4540 # others are converted to arrays, based on a set of somewhat convoluted heuristics 4541 # (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343) 4542 # In JAX, we raise an informative TypeError for *all* non-tuple sequences. 4543 if _any(_should_unpack_list_index(i) for i in idx): 4544 msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " 4545 "use `arr[tuple(seq)]` instead of `arr[seq]`. " 4546 "See https://github.com/google/jax/issues/4564 for more information.") 4547 else: 4548 msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " 4549 "use `arr[array(seq)]` instead of `arr[seq]`. 
" 4550 "See https://github.com/google/jax/issues/4564 for more information.") 4551 raise TypeError(msg) 4552 else: 4553 idx = (idx,) 4554 return idx 4555 4556def _expand_bool_indices(idx): 4557 """Converts concrete bool indexes into advanced integer indexes.""" 4558 out = [] 4559 for i in idx: 4560 try: 4561 abstract_i = core.get_aval(i) 4562 except TypeError: 4563 abstract_i = None 4564 if (isinstance(abstract_i, ShapedArray) and issubdtype(abstract_i.dtype, bool_) 4565 or isinstance(i, list) and _all(not _shape(e) and issubdtype(_dtype(e), bool_) 4566 for e in i)): 4567 if isinstance(i, list): 4568 i = array(i) 4569 abstract_i = core.get_aval(i) 4570 4571 if not type(abstract_i) is ConcreteArray: 4572 # TODO(mattjj): improve this error by tracking _why_ the indices are not 4573 # concrete 4574 raise IndexError("Array boolean indices must be concrete.") 4575 else: 4576 out.extend(np.where(i)) 4577 else: 4578 out.append(i) 4579 return tuple(out) 4580 4581def _is_slice_none(idx): 4582 """Return True if idx is equal to slice(None), False otherwise.""" 4583 if isinstance(idx, slice): 4584 return idx.start is None and idx.stop is None and idx.step is None 4585 4586# TODO(mattjj): clean up this logic 4587def _is_advanced_int_indexer(idx): 4588 """Returns True if idx should trigger int array indexing, False otherwise.""" 4589 # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing 4590 assert isinstance(idx, tuple) 4591 if _all(np.ndim(elt) == 0 for elt in idx): 4592 return False 4593 return _all(e is None or e is Ellipsis or isinstance(e, slice) 4594 or _is_int_arraylike(e) for e in idx) 4595 4596def _is_int_arraylike(x): 4597 """Returns True if x is array-like with integer dtype, False otherwise.""" 4598 return (isinstance(x, int) and not isinstance(x, bool) 4599 or issubdtype(getattr(x, "dtype", None), np.integer) 4600 or isinstance(x, (list, tuple)) and _all(_is_int_arraylike(e) for e in x)) 4601 4602 4603def 
_canonicalize_tuple_index(arr_ndim, idx): 4604 """Helper to remove Ellipsis and add in the implicit trailing slice(None).""" 4605 len_without_none = _sum(1 for e in idx if e is not None and e is not Ellipsis) 4606 if len_without_none > arr_ndim: 4607 msg = "Too many indices for array: {} non-None/Ellipsis indices for dim {}." 4608 raise IndexError(msg.format(len_without_none, arr_ndim)) 4609 ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis) 4610 ellipsis_index = next(ellipses, None) 4611 if ellipsis_index is not None: 4612 if next(ellipses, None) is not None: 4613 msg = "Multiple ellipses (...) not supported: {}." 4614 raise IndexError(msg.format(list(map(type, idx)))) 4615 colons = (slice(None),) * (arr_ndim - len_without_none) 4616 idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:] 4617 elif len_without_none < arr_ndim: 4618 colons = (slice(None),) * (arr_ndim - len_without_none) 4619 idx = tuple(idx) + colons 4620 return idx 4621 4622def _polymorphic_slice_indices(idx: slice, size: Union[int, Poly]): 4623 # like idx.indices(size), but allows for polymorphic indices and size 4624 # see https://github.com/python/cpython/blob/6d6508765514c7c10719478a0430f5e47c9a96ac/Objects/sliceobject.c#L372 4625 assert isinstance(idx, slice) 4626 4627 step = 1 if idx.step is None else idx.step 4628 step_is_negative = step < 0 4629 lower = -1 if step_is_negative else 0 4630 upper = size + lower 4631 4632 def sanitize(index, default): 4633 if index is None: 4634 return default 4635 elif type(index) is Poly: 4636 return index 4637 elif index < 0: 4638 return _max(index + size, lower) 4639 else: 4640 return _min(index, upper) 4641 4642 start = sanitize(idx.start, default=upper if step_is_negative else lower) 4643 stop = sanitize(idx.stop, default=lower if step_is_negative else upper) 4644 return start, stop, step 4645 4646def _static_idx(idx: slice, size: Union[int, Poly]): 4647 """Helper function to compute the static slice start/limit/stride values.""" 4648 
if _any(type(s) is Poly for s in (idx.start, idx.stop, idx.step, size)): 4649 start, stop, step = _polymorphic_slice_indices(idx, size) 4650 elif isinstance(size, int): 4651 start, stop, step = idx.indices(size) 4652 else: 4653 raise TypeError(size) 4654 4655 if type(start) is not Poly and type(stop) is not Poly: 4656 if (step < 0 and stop >= start) or (step > 0 and start >= stop): 4657 return 0, 0, 1, False # sliced to size zero 4658 4659 if step > 0: 4660 return start, stop, step, False 4661 else: 4662 k = (start - stop - 1) % (-step) 4663 return stop + k + 1, start + 1, -step, True 4664 4665 4666blackman = _wrap_numpy_nullary_function(np.blackman) 4667bartlett = _wrap_numpy_nullary_function(np.bartlett) 4668hamming = _wrap_numpy_nullary_function(np.hamming) 4669hanning = _wrap_numpy_nullary_function(np.hanning) 4670# TODO: lower `kaiser` via lax to allow non-constant beta values. 4671kaiser = _wrap_numpy_nullary_function(np.kaiser) 4672 4673def _gcd_cond_fn(xs): 4674 x1, x2 = xs 4675 return any(x2 != 0) 4676 4677def _gcd_body_fn(xs): 4678 x1, x2 = xs 4679 x1, x2 = (where(x2 != 0, x2, x1), 4680 where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0))) 4681 return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) 4682 4683@_wraps(getattr(np, "gcd", None)) 4684def gcd(x1, x2): 4685 _check_arraylike("gcd", x1, x2) 4686 if (not issubdtype(_dtype(x1), integer) or 4687 not issubdtype(_dtype(x2), integer)): 4688 raise ValueError("Arguments to jax.numpy.gcd must be integers.") 4689 x1, x2 = _promote_dtypes(x1, x2) 4690 x1, x2 = broadcast_arrays(x1, x2) 4691 gcd, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, (abs(x1), abs(x2))) 4692 return gcd 4693 4694 4695@_wraps(getattr(np, "lcm", None)) 4696def lcm(x1, x2): 4697 _check_arraylike("lcm", x1, x2) 4698 x1, x2 = _promote_dtypes(x1, x2) 4699 d = gcd(x1, x2) 4700 return where(d == 0, lax._const(d, 0), 4701 abs(multiply(x1, floor_divide(x2, d)))) 4702 4703 4704@_wraps(np.extract) 4705def extract(condition, arr): 4706 return 
compress(ravel(condition), ravel(arr)) 4707 4708 4709@_wraps(np.compress) 4710def compress(condition, a, axis: Optional[int] = None, out=None): 4711 if out is not None: 4712 raise NotImplementedError("The 'out' argument to jnp.compress is not supported.") 4713 if ndim(condition) != 1: 4714 raise ValueError("condition must be a 1D array") 4715 condition = asarray(condition).astype(bool) 4716 a = asarray(a) 4717 if axis is None: 4718 axis = 0 4719 a = ravel(a) 4720 else: 4721 a = moveaxis(a, axis, 0) 4722 condition, extra = condition[:a.shape[0]], condition[a.shape[0]:] 4723 if any(extra): 4724 raise ValueError("condition contains entries that are out of bounds") 4725 a = a[:condition.shape[0]] 4726 return moveaxis(a[condition], 0, axis) 4727 4728 4729@_wraps(np.cov) 4730def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None, 4731 aweights=None): 4732 if y is not None: raise NotImplementedError( 4733 "jax.numpy.cov not implemented for nontrivial y. " 4734 "Open a feature request at https://github.com/google/jax/issues !") 4735 4736 m, = _promote_args_inexact("cov", m) 4737 4738 if m.ndim > 2: 4739 raise ValueError("m has more than 2 dimensions") # same as numpy error 4740 X = atleast_2d(m) 4741 if not rowvar and X.shape[0] != 1: 4742 X = X.T 4743 if X.shape[0] == 0: 4744 return array([]).reshape(0, 0) 4745 if ddof is None: 4746 ddof = 1 if bias == 0 else 0 4747 4748 w = None 4749 if fweights is not None: 4750 _check_arraylike("cov", fweights) 4751 if ndim(fweights) > 1: 4752 raise RuntimeError("cannot handle multidimensional fweights") 4753 if shape(fweights)[0] != X.shape[1]: 4754 raise RuntimeError("incompatible numbers of samples and fweights") 4755 if not issubdtype(_dtype(fweights), integer): 4756 raise TypeError("fweights must be integer.") 4757 # Ensure positive fweights; note that numpy raises an error on negative fweights. 
4758 w = asarray(abs(fweights)) 4759 if aweights is not None: 4760 _check_arraylike("cov", aweights) 4761 if ndim(aweights) > 1: 4762 raise RuntimeError("cannot handle multidimensional aweights") 4763 if shape(aweights)[0] != X.shape[1]: 4764 raise RuntimeError("incompatible numbers of samples and aweights") 4765 # Ensure positive aweights: note that numpy raises an error for negative aweights. 4766 aweights = abs(aweights) 4767 w = aweights if w is None else w * aweights 4768 4769 avg, w_sum = average(X, axis=1, weights=w, returned=True) 4770 w_sum = w_sum[0] 4771 4772 if w is None: 4773 f = X.shape[1] - ddof 4774 elif ddof == 0: 4775 f = w_sum 4776 elif aweights is None: 4777 f = w_sum - ddof 4778 else: 4779 f = w_sum - ddof * sum(w * aweights) / w_sum 4780 4781 X = X - avg[:, None] 4782 X_T = X.T if w is None else (X * w).T 4783 return true_divide(dot(X, X_T.conj()), f).squeeze() 4784 4785 4786@_wraps(np.corrcoef) 4787def corrcoef(x, y=None, rowvar=True): 4788 _check_arraylike("corrcoef", x) 4789 c = cov(x, y, rowvar) 4790 if len(shape(c)) == 0: 4791 # scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise 4792 return divide(c, c) 4793 d = diag(c) 4794 stddev = sqrt(real(d)) 4795 c = divide(c, stddev[:,None]) 4796 c = divide(c, stddev[None,:]) 4797 4798 real_part = clip(real(c), -1, 1) 4799 if iscomplexobj(c): 4800 complex_part = clip(imag(c), -1, 1) 4801 c = lax.complex(real_part, complex_part) 4802 else: 4803 c = real_part 4804 return c 4805 4806 4807@_wraps(getattr(np, "quantile", None)) 4808def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 4809 overwrite_input=False, interpolation="linear", keepdims=False): 4810 _check_arraylike("quantile", a, q) 4811 if overwrite_input or out is not None: 4812 msg = ("jax.numpy.quantile does not support overwrite_input=True or " 4813 "out != None") 4814 raise ValueError(msg) 4815 return _quantile(a, q, axis, interpolation, keepdims, False) 4816 4817@_wraps(getattr(np, 
"nanquantile", None)) 4818def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, 4819 out=None, overwrite_input=False, interpolation="linear", 4820 keepdims=False): 4821 _check_arraylike("nanquantile", a, q) 4822 if overwrite_input or out is not None: 4823 msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " 4824 "out != None") 4825 raise ValueError(msg) 4826 return _quantile(a, q, axis, interpolation, keepdims, True) 4827 4828 4829@partial(jit, static_argnums=(2, 3, 4, 5)) 4830def _quantile(a, q, axis, interpolation, keepdims, squash_nans): 4831 if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]: 4832 raise ValueError("interpolation can only be 'linear', 'lower', 'higher', " 4833 "'midpoint', or 'nearest'") 4834 a = asarray(a, dtype=promote_types(_dtype(a), float32)) 4835 q = asarray(q, dtype=promote_types(_dtype(q), float32)) 4836 if axis is None: 4837 a = ravel(a) 4838 axis = 0 4839 elif isinstance(axis, tuple): 4840 raise NotImplementedError("Tuple values for axis are not implemented") 4841 else: 4842 axis = _canonicalize_axis(axis, ndim(a)) 4843 4844 q_shape = shape(q) 4845 q_ndim = ndim(q) 4846 if q_ndim > 1: 4847 raise ValueError("q must be have rank <= 1, got shape {}".format(shape(q))) 4848 4849 a_shape = shape(a) 4850 a = lax.sort(a, dimension=axis) 4851 4852 if squash_nans: 4853 counts = sum(logical_not(isnan(a)), axis=axis, dtype=q.dtype, 4854 keepdims=keepdims) 4855 shape_after_reduction = counts.shape 4856 q = lax.expand_dims( 4857 q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) 4858 counts = lax.expand_dims(counts, tuple(range(q_ndim))) 4859 q = lax.mul(q, lax.sub(counts, _constant_like(q, 1))) 4860 low = lax.floor(q) 4861 high = lax.ceil(q) 4862 high_weight = lax.sub(q, low) 4863 low_weight = lax.sub(_constant_like(high_weight, 1), high_weight) 4864 4865 low = lax.max(_constant_like(low, 0), lax.min(low, counts - 1)) 4866 high = lax.max(_constant_like(high, 0), 
lax.min(high, counts - 1)) 4867 low = lax.convert_element_type(low, int64) 4868 high = lax.convert_element_type(high, int64) 4869 out_shape = q_shape + shape_after_reduction 4870 index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim) 4871 for dim in range(len(shape_after_reduction))] 4872 if keepdims: 4873 index[axis] = low 4874 else: 4875 index.insert(axis, low) 4876 low_value = a[tuple(index)] 4877 index[axis] = high 4878 high_value = a[tuple(index)] 4879 else: 4880 n = a_shape[axis] 4881 q = lax.mul(q, _constant_like(q, n - 1)) 4882 low = lax.floor(q) 4883 high = lax.ceil(q) 4884 high_weight = lax.sub(q, low) 4885 low_weight = lax.sub(_constant_like(high_weight, 1), high_weight) 4886 4887 low = lax.clamp(_constant_like(low, 0), low, _constant_like(low, n - 1)) 4888 high = lax.clamp(_constant_like(high, 0), high, _constant_like(high, n - 1)) 4889 low = lax.convert_element_type(low, int64) 4890 high = lax.convert_element_type(high, int64) 4891 4892 slice_sizes = list(a_shape) 4893 slice_sizes[axis] = 1 4894 dnums = lax.GatherDimensionNumbers( 4895 offset_dims=tuple(range( 4896 q_ndim, 4897 len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)), 4898 collapsed_slice_dims=() if keepdims else (axis,), 4899 start_index_map=(axis,)) 4900 low_value = lax.gather(a, low[..., None], dimension_numbers=dnums, 4901 slice_sizes=slice_sizes) 4902 high_value = lax.gather(a, high[..., None], dimension_numbers=dnums, 4903 slice_sizes=slice_sizes) 4904 if q_ndim == 1: 4905 low_weight = lax.broadcast_in_dim(low_weight, low_value.shape, 4906 broadcast_dimensions=(0,)) 4907 high_weight = lax.broadcast_in_dim(high_weight, high_value.shape, 4908 broadcast_dimensions=(0,)) 4909 4910 if interpolation == "linear": 4911 result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight), 4912 lax.mul(high_value.astype(q.dtype), high_weight)) 4913 elif interpolation == "lower": 4914 result = low_value 4915 elif interpolation == "higher": 4916 result = high_value 4917 elif 
interpolation == "nearest": 4918 pred = lax.le(high_weight, _constant_like(high_weight, 0.5)) 4919 result = lax.select(pred, low_value, high_value) 4920 elif interpolation == "midpoint": 4921 result = lax.mul(lax.add(low_value, high_value), _constant_like(low_value, 0.5)) 4922 else: 4923 raise ValueError(f"interpolation={interpolation!r} not recognized") 4924 4925 return lax.convert_element_type(result, a.dtype) 4926 4927 4928@partial(jit, static_argnums=2) 4929@partial(vectorize, excluded={0, 2}) 4930def _searchsorted(a, v, side): 4931 if len(a) == 0: 4932 return 0 4933 op = operator.le if side == 'left' else operator.lt 4934 4935 def body_fun(i, state): 4936 low, high = state 4937 mid = (low + high) // 2 4938 go_left = op(v, a[mid]) 4939 return (where(go_left, low, mid), where(go_left, mid, high)) 4940 4941 n_levels = int(np.ceil(np.log2(len(a) + 1))) 4942 return lax.fori_loop(0, n_levels, body_fun, (0, len(a)))[1] 4943 4944 4945@_wraps(np.searchsorted) 4946def searchsorted(a, v, side='left', sorter=None): 4947 if side not in ['left', 'right']: 4948 raise ValueError(f"{side!r} is an invalid value for keyword 'side'") 4949 if sorter is not None: 4950 raise NotImplementedError("sorter is not implemented") 4951 a = asarray(a) 4952 v = asarray(v) 4953 if ndim(a) != 1: 4954 raise ValueError("a should be 1-dimensional") 4955 return _searchsorted(a, v, side) 4956 4957 4958@_wraps(np.digitize) 4959def digitize(x, bins, right=False): 4960 if len(bins) == 0: 4961 return zeros(x, dtype=dtypes.canonicalize_dtype(int_)) 4962 side = 'right' if not right else 'left' 4963 return where( 4964 bins[-1] >= bins[0], 4965 searchsorted(bins, x, side=side), 4966 len(bins) - searchsorted(bins[::-1], x, side=side) 4967 ) 4968 4969_PIECEWISE_DOC = """\ 4970Unlike `np.piecewise`, :py:func:`jax.numpy.piecewise` requires functions in 4971`funclist` to be traceable by JAX, as it is implemeted via :func:`jax.lax.switch`. 4972See the :func:`jax.lax.switch` documentation for more information. 
4973""" 4974 4975@_wraps(np.piecewise, lax_description=_PIECEWISE_DOC) 4976def piecewise(x, condlist, funclist, *args, **kw): 4977 _check_arraylike("piecewise", x) 4978 condlist = array(condlist, dtype=bool_) 4979 nc, nf = len(condlist), len(funclist) 4980 if nf == nc + 1: 4981 funclist = funclist[-1:] + funclist[:-1] 4982 elif nf == nc: 4983 funclist = [0] + list(funclist) 4984 else: 4985 raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}") 4986 indices = argmax(cumsum(vstack([zeros_like(condlist[:1]), condlist]), 0), 0) 4987 dtype = _dtype(x) 4988 def _call(f): 4989 return lambda x: f(x, *args, **kw).astype(dtype) 4990 def _const(v): 4991 return lambda x: full_like(x, v) 4992 funclist = [_call(f) if callable(f) else _const(f) for f in funclist] 4993 return vectorize(lax.switch, excluded=(1,))(indices, funclist, x) 4994 4995 4996@_wraps(np.percentile) 4997def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, 4998 out=None, overwrite_input=False, interpolation="linear", 4999 keepdims=False): 5000 _check_arraylike("percentile", a) 5001 q = true_divide(asarray(q), float32(100.0)) 5002 return quantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, 5003 interpolation=interpolation, keepdims=keepdims) 5004 5005@_wraps(np.nanpercentile) 5006def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, 5007 out=None, overwrite_input=False, interpolation="linear", 5008 keepdims=False): 5009 _check_arraylike("nanpercentile", a) 5010 q = true_divide(asarray(q), float32(100.0)) 5011 return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, 5012 interpolation=interpolation, keepdims=keepdims) 5013 5014@_wraps(np.median) 5015def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 5016 overwrite_input=False, keepdims=False): 5017 _check_arraylike("median", a) 5018 return quantile(a, 0.5, axis=axis, out=out, overwrite_input=overwrite_input, 5019 
keepdims=keepdims, interpolation='midpoint') 5020 5021@_wraps(np.nanmedian) 5022def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, 5023 overwrite_input=False, keepdims=False): 5024 _check_arraylike("nanmedian", a) 5025 return nanquantile(a, 0.5, axis=axis, out=out, 5026 overwrite_input=overwrite_input, keepdims=keepdims, 5027 interpolation='midpoint') 5028 5029 5030def _astype(arr, dtype): 5031 lax._check_user_dtype_supported(dtype, "astype") 5032 return lax.convert_element_type(arr, dtype) 5033 5034 5035def _nbytes(arr): 5036 return size(arr) * _dtype(arr).itemsize 5037 5038 5039def _view(arr, dtype=None, type=None): 5040 lax._check_user_dtype_supported(dtype, "view") 5041 if type is not None: 5042 raise NotImplementedError("`type` argument of array.view()") 5043 if dtype is None: 5044 return arr 5045 arr_dtype = _dtype(arr) 5046 if arr_dtype == dtype: 5047 return arr 5048 # bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type. 5049 # We work around this by casting bool to uint8. 
5050 if arr_dtype == bool_: 5051 arr = arr.astype(uint8) 5052 nbits_in = 8 * arr_dtype.itemsize 5053 nbits_out = 8 * _dtype(dtype).itemsize 5054 if nbits_in == nbits_out: 5055 if dtype == bool_: 5056 return lax.bitcast_convert_type(arr, uint8).astype(dtype) 5057 return lax.bitcast_convert_type(arr, dtype) 5058 if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0: 5059 raise ValueError("When changing to a larger dtype, its size must be a divisor " 5060 "of the total size in bytes of the last axis of the array.") 5061 byte_dtypes = {8: uint8, 16: uint16, 32: uint32, 64: uint64} 5062 if nbits_in not in byte_dtypes: 5063 raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}") 5064 if nbits_out not in byte_dtypes: 5065 raise NotImplementedError(f"arr.view(dtype) for dtype={dtype}") 5066 dt_in = byte_dtypes[nbits_in] 5067 dt_out = byte_dtypes[nbits_out] 5068 arr_bytes = lax.bitcast_convert_type(arr, dt_in) 5069 if nbits_in < nbits_out: 5070 shifts = arange(0, nbits_out, nbits_in, dtype=dt_out) 5071 arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out) 5072 arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out) 5073 else: 5074 shifts = arange(0, nbits_in, nbits_out, dtype=dt_in) 5075 arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out) 5076 arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,)) 5077 if dtype == bool_: 5078 return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype) 5079 return lax.bitcast_convert_type(arr_bytes, dtype) 5080 5081### track unimplemented functions 5082 5083_NOT_IMPLEMENTED_DESC = """ 5084*** This function is not yet implemented by jax.numpy, and will raise NotImplementedError *** 5085""" 5086 5087def _not_implemented(fun): 5088 @_wraps(fun, update_doc=False, lax_description=_NOT_IMPLEMENTED_DESC) 5089 def wrapped(*args, **kwargs): 5090 msg = "Numpy function {} not yet implemented" 5091 raise NotImplementedError(msg.format(fun)) 5092 
return wrapped 5093 5094 5095### add method and operator overloads to arraylike classes 5096 5097# We add operator overloads to DeviceArray and ShapedArray. These method and 5098# operator overloads mainly just forward calls to the corresponding lax_numpy 5099# functions, which can themselves handle instances from any of these classes. 5100 5101_scalar_types = (int, float, complex, np.generic) 5102 5103def _defer_to_unrecognized_arg(binary_op): 5104 # Ensure that other array types have the chance to override arithmetic. 5105 def deferring_binary_op(self, other): 5106 if not isinstance(other, _scalar_types + _arraylike_types + (core.Tracer,)): 5107 return NotImplemented 5108 return binary_op(self, other) 5109 return deferring_binary_op 5110 5111def _swap_args(f): 5112 return lambda x, y: f(y, x) 5113 5114def _unimplemented_setitem(self, i, x): 5115 msg = ("'{}' object does not support item assignment. JAX arrays are " 5116 "immutable; perhaps you want jax.ops.index_update or " 5117 "jax.ops.index_add instead?") 5118 raise TypeError(msg.format(type(self))) 5119 5120def _operator_round(number, ndigits=None): 5121 out = round(number, decimals=ndigits or 0) 5122 # If `ndigits` is None, for a builtin float round(7.5) returns an integer. 
5123 return out.astype(int) if ndigits is None else out 5124 5125_operators = { 5126 "getitem": _rewriting_take, 5127 "setitem": _unimplemented_setitem, 5128 "neg": negative, 5129 "pos": positive, 5130 "eq": _defer_to_unrecognized_arg(equal), 5131 "ne": _defer_to_unrecognized_arg(not_equal), 5132 "lt": _defer_to_unrecognized_arg(less), 5133 "le": _defer_to_unrecognized_arg(less_equal), 5134 "gt": _defer_to_unrecognized_arg(greater), 5135 "ge": _defer_to_unrecognized_arg(greater_equal), 5136 "abs": abs, 5137 "add": _defer_to_unrecognized_arg(add), 5138 "radd": _defer_to_unrecognized_arg(add), 5139 "sub": _defer_to_unrecognized_arg(subtract), 5140 "rsub": _defer_to_unrecognized_arg(_swap_args(subtract)), 5141 "mul": _defer_to_unrecognized_arg(multiply), 5142 "rmul": _defer_to_unrecognized_arg(multiply), 5143 "div": _defer_to_unrecognized_arg(divide), 5144 "rdiv": _defer_to_unrecognized_arg(_swap_args(divide)), 5145 "truediv": _defer_to_unrecognized_arg(true_divide), 5146 "rtruediv": _defer_to_unrecognized_arg(_swap_args(true_divide)), 5147 "floordiv": _defer_to_unrecognized_arg(floor_divide), 5148 "rfloordiv": _defer_to_unrecognized_arg(_swap_args(floor_divide)), 5149 "divmod": _defer_to_unrecognized_arg(divmod), 5150 "rdivmod": _defer_to_unrecognized_arg(_swap_args(divmod)), 5151 "mod": _defer_to_unrecognized_arg(mod), 5152 "rmod": _defer_to_unrecognized_arg(_swap_args(mod)), 5153 "pow": _defer_to_unrecognized_arg(power), 5154 "rpow": _defer_to_unrecognized_arg(_swap_args(power)), 5155 "matmul": _defer_to_unrecognized_arg(matmul), 5156 "rmatmul": _defer_to_unrecognized_arg(_swap_args(matmul)), 5157 "and": _defer_to_unrecognized_arg(bitwise_and), 5158 "rand": _defer_to_unrecognized_arg(bitwise_and), 5159 "or": _defer_to_unrecognized_arg(bitwise_or), 5160 "ror": _defer_to_unrecognized_arg(bitwise_or), 5161 "xor": _defer_to_unrecognized_arg(bitwise_xor), 5162 "rxor": _defer_to_unrecognized_arg(bitwise_xor), 5163 "invert": bitwise_not, 5164 "lshift": 
_defer_to_unrecognized_arg(left_shift), 5165 "rshift": _defer_to_unrecognized_arg(right_shift), 5166 "rlshift": _defer_to_unrecognized_arg(_swap_args(left_shift)), 5167 "rrshift": _defer_to_unrecognized_arg(_swap_args(right_shift)), 5168 "round": _operator_round, 5169} 5170 5171# These numpy.ndarray methods are just refs to an equivalent numpy function 5172_nondiff_methods = ["all", "any", "argmax", "argmin", "argpartition", "argsort", 5173 "nonzero", "searchsorted", "round"] 5174_diff_methods = ["clip", "conj", "conjugate", "cumprod", "cumsum", 5175 "diagonal", "dot", "max", "mean", "min", "prod", "ptp", 5176 "ravel", "repeat", "sort", "squeeze", "std", "sum", 5177 "swapaxes", "take", "tile", "trace", "transpose", "var"] 5178 5179# These methods are mentioned explicitly by nondiff_methods, so we create 5180# _not_implemented implementations of them here rather than in __init__.py. 5181# TODO(phawkins): implement these. 5182argpartition = _not_implemented(np.argpartition) 5183_NOT_IMPLEMENTED = ['argpartition'] 5184 5185# Set up operator, method, and property forwarding on Tracer instances containing 5186# ShapedArray avals by following the forwarding conventions for Tracer. 
5187# Forward operators using a single-underscore-prefix naming convention: 5188for operator_name, function in _operators.items(): 5189 setattr(ShapedArray, "_{}".format(operator_name), staticmethod(function)) 5190# Forward methods and properties using core.aval_method and core.aval_property: 5191for method_name in _nondiff_methods + _diff_methods: 5192 setattr(ShapedArray, method_name, core.aval_method(globals()[method_name])) 5193setattr(ShapedArray, "reshape", core.aval_method(_reshape)) 5194setattr(ShapedArray, "flatten", core.aval_method(ravel)) 5195setattr(ShapedArray, "T", core.aval_property(transpose)) 5196setattr(ShapedArray, "real", core.aval_property(real)) 5197setattr(ShapedArray, "imag", core.aval_property(imag)) 5198setattr(ShapedArray, "astype", core.aval_method(_astype)) 5199setattr(ShapedArray, "view", core.aval_method(_view)) 5200setattr(ShapedArray, "nbytes", core.aval_property(_nbytes)) 5201 5202 5203# Forward operators, methods, and properties on DeviceArray to lax_numpy 5204# functions (with no Tracers involved; this forwarding is direct) 5205for device_array in [_DeviceArray, _CppDeviceArray]: 5206 for operator_name, function in _operators.items(): 5207 setattr(device_array, "__{}__".format(operator_name), function) 5208 for method_name in _nondiff_methods + _diff_methods: 5209 setattr(device_array, method_name, globals()[method_name]) 5210 setattr(device_array, "reshape", _reshape) 5211 setattr(device_array, "flatten", ravel) 5212 setattr(device_array, "T", property(transpose)) 5213 setattr(device_array, "real", property(real)) 5214 setattr(device_array, "imag", property(imag)) 5215 setattr(device_array, "astype", _astype) 5216 setattr(device_array, "view", _view) 5217 setattr(device_array, "nbytes", property(_nbytes)) 5218 5219 5220# Experimental support for NumPy's module dispatch with NEP-37. 
# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)


def __array_module__(self, types):
  """NEP-37 ``__array_module__`` hook for JAX arrays.

  Args:
    self: the array (or tracer) the hook was invoked on; unused.
    types: the collection of array types participating in the operation.

  Returns:
    The ``jax.numpy`` module if every type in ``types`` is a subclass of a JAX
    array type or ``np.ndarray``; otherwise ``NotImplemented``, so that another
    array library may handle the operation.
  """
  # `builtins.all` is used because this module defines its own `all`.
  if builtins.all(issubclass(t, _HANDLED_ARRAY_TYPES) for t in types):
    return jax.numpy
  else:
    return NotImplemented

# ShapedArray receives the hook as a staticmethod under a single-underscore
# name (the same convention used for operator forwarding on ShapedArray); the
# DeviceArray classes receive it directly as the dunder method.
setattr(ShapedArray, "_array_module", staticmethod(__array_module__))
setattr(_DeviceArray, "__array_module__", __array_module__)
setattr(_CppDeviceArray, "__array_module__", __array_module__)


# Extra methods that are handy
setattr(ShapedArray, "broadcast", core.aval_method(lax.broadcast))
setattr(ShapedArray, "broadcast_in_dim", core.aval_method(lax.broadcast_in_dim))
setattr(ShapedArray, "split", core.aval_method(split))
for device_array in [_DeviceArray, _CppDeviceArray]:
  setattr(device_array, "broadcast", lax.broadcast)
  setattr(device_array, "broadcast_in_dim", lax.broadcast_in_dim)
  setattr(device_array, "split", split)


def _compress_method(a, condition, axis=None, out=None):
  """Method form of ``compress``: ``a.compress(condition)`` is
  ``compress(condition, a)``, i.e. the array is the first argument, as for
  ``np.ndarray.compress``."""
  return compress(condition, a, axis, out)

# Wrap in core.aval_method, like every other method forwarded to ShapedArray:
# Tracer method forwarding relies on that wrapper to bind the method to the
# tracer. A plain function attribute would instead be bound to the aval it is
# looked up on, and `compress` would receive the abstract value as `a`.
setattr(ShapedArray, "compress", core.aval_method(_compress_method))
setattr(_DeviceArray, "compress", _compress_method)
setattr(_CppDeviceArray, "compress", _compress_method)


@partial(jit, static_argnums=(1, 2, 3))
def _multi_slice(arr,
                 start_indices: Tuple[Tuple[int, ...], ...],
                 limit_indices: Tuple[Tuple[int, ...], ...],
                 removed_dims: Tuple[Tuple[int, ...], ...]):
  """Extracts multiple slices from `arr`.

  This is used to shard DeviceArray arguments to pmap. It's implemented as a
  DeviceArray method here to avoid circular imports.

  Args:
    arr: the array to slice.
    start_indices: for each slice, the per-dimension start indices.
    limit_indices: for each slice, the per-dimension limit indices.
    removed_dims: for each slice, the dimensions to drop from that slice's
      result (they must have size 1 after slicing, since they are removed by
      a reshape).

  Returns:
    A list with one sliced array per entry of ``start_indices``.
  """
  results = []
  for starts, limits, removed in safe_zip(start_indices, limit_indices,
                                          removed_dims):
    sliced = lax.slice(arr, starts, limits)
    if removed:
      # Drop only this slice's dimensions (`removed`), not the entries of
      # `removed_dims` belonging to every slice.
      sliced = sliced.reshape(np.delete(sliced.shape, removed))
    results.append(sliced)
  return results
setattr(_DeviceArray, "_multi_slice", _multi_slice)
setattr(_CppDeviceArray, "_multi_slice", _multi_slice)


# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
  # Note: this docstring will appear as the docstring for the `at` property.
  """Indexable helper object to call indexed update functions.

  The `at` property is syntactic sugar for calling the indexed update functions
  defined in :mod:`jax.ops`, and acts as a pure equivalent of in-place
  modifications.

  In particular:

  - ``x = x.at[idx].set(y)`` is a pure equivalent of ``x[idx] = y``.
  - ``x = x.at[idx].add(y)`` is a pure equivalent of ``x[idx] += y``.
  - ``x = x.at[idx].mul(y)`` is a pure equivalent of ``x[idx] *= y``.
  - ``x = x.at[idx].min(y)`` is a pure equivalent of
    ``x[idx] = minimum(x[idx], y)``.
  - ``x = x.at[idx].max(y)`` is a pure equivalent of
    ``x[idx] = maximum(x[idx], y)``.
  """
  __slots__ = ("array",)

  def __init__(self, array):
    self.array = array

  def __getitem__(self, index):
    return _IndexUpdateRef(self.array, index)

  def __repr__(self):
    return f"_IndexUpdateHelper({repr(self.array)})"


class _IndexUpdateRef:
  """Helper object to call indexed update functions for an (advanced) index.

  This object references a source array and a specific indexer into that array.
  Methods on this object return copies of the source array that have been
  modified at the positions specified by the indexer.

  Every update method also accepts ``indices_are_sorted`` and
  ``unique_indices``, which are forwarded unchanged to the corresponding
  :mod:`jax.ops` function.
  """
  __slots__ = ("array", "index")

  def __init__(self, array, index):
    self.array = array
    self.index = index

  def __repr__(self):
    return f"_IndexUpdateRef({repr(self.array)}, {repr(self.index)})"

  def set(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] = y``.

    ``x.at[idx].set(y)`` is syntactic sugar for
    ``jax.ops.index_update(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_update(self.array, self.index, values,
                            indices_are_sorted=indices_are_sorted,
                            unique_indices=unique_indices)

  def add(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] += y``.

    ``x.at[idx].add(y)`` is syntactic sugar for
    ``jax.ops.index_add(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] += y``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_add(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)

  def mul(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] *= y``.

    ``x.at[idx].mul(y)`` is syntactic sugar for
    ``jax.ops.index_mul(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] *= y``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_mul(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)

  def min(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] = minimum(x[idx], y)``.

    ``x.at[idx].min(y)`` is syntactic sugar for
    ``jax.ops.index_min(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>`
    ``x[idx] = minimum(x[idx], y)``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_min(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)

  def max(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] = maximum(x[idx], y)``.

    ``x.at[idx].max(y)`` is syntactic sugar for
    ``jax.ops.index_max(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>`
    ``x[idx] = maximum(x[idx], y)``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_max(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)

setattr(_DeviceArray, "at", property(_IndexUpdateHelper))
setattr(_CppDeviceArray, "at", property(_IndexUpdateHelper))
setattr(ShapedArray, "at", core.aval_property(_IndexUpdateHelper))