1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# pytype: skip-file
16"""
17Implements the NumPy API, using the primitives in :mod:`jax.lax`.
18
19NumPy operations are implemented in Python in terms of the primitive operations
20in :mod:`jax.lax`. Since NumPy operations are not primitive and instead are
21implemented in terms of :mod:`jax.lax` operations, we do not need to define
22transformation rules such as gradient or batching rules. Instead,
23transformations for NumPy primitives can be derived from the transformation
24rules for the underlying :code:`lax` primitives.
25"""
26
27import builtins
28import collections
29import operator
30import os
31import types
32from typing import Sequence, FrozenSet, Optional, Tuple, Union, Iterable, cast
33from textwrap import dedent as _dedent
34import warnings
35
36import numpy as np
37import opt_einsum
38
39import jax
40from jax import jit, custom_jvp
41from .vectorize import vectorize
42from .util import _wraps
43from jax import core
44from jax import dtypes
45from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
46from jax.config import flags, config
47from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
48from jax.interpreters.masking import Poly
49from jax import lax
50from jax._src.lax.lax import _device_put_raw
51from jax import ops
52from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip,
53                           canonicalize_axis as _canonicalize_axis)
54from jax.tree_util import tree_leaves, tree_flatten
55
# Register the `jax_numpy_rank_promotion` flag. Its default is read from the
# JAX_NUMPY_RANK_PROMOTION environment variable and falls back to 'allow'.
FLAGS = flags.FLAGS
flags.DEFINE_enum(
    'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
    enum_values=['allow', 'warn', 'raise'],
    help=
    'Control NumPy-style automatic rank promotion broadcasting '
    '("allow", "warn", or "raise").')

# Exposed under the NumPy-style name `newaxis`; the value is simply None.
newaxis = None
65
# Common docstring additions:

# Description of the extra ``precision`` keyword; convolve and correlate in
# this file pass it to ``_wraps`` as ``lax_description``.
_PRECISION_DOC = """\
In addition to the original NumPy arguments listed below, also supports
``precision`` for extra control over matrix-multiplication precision
on supported devices. ``precision`` may be set to ``None``, which means
default precision for the backend, a ``lax.Precision`` enum value
(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple
of two ``lax.Precision`` enums indicating separate precision for each argument.
"""
76
# We replace some builtin names to follow Numpy's API, so we capture here.
# Each underscore-prefixed name keeps a reference to the Python builtin of the
# same name.
_abs = builtins.abs
_all = builtins.all
_any = builtins.any
_max = builtins.max
_min = builtins.min
_sum = builtins.sum
_divmod = builtins.divmod

# NumPy constants
# (re-exported unchanged from numpy)

pi = np.pi
e = np.e
euler_gamma = np.euler_gamma
inf = np.inf
NINF = np.NINF
PZERO = np.PZERO
NZERO = np.NZERO
nan = np.nan

# And some numpy utility functions
set_printoptions = np.set_printoptions
99
# We want isinstance(x, np.ndarray) checks in user code to work with our
# array-like types, including DeviceArray and UnshapedArray (i.e. the abstract
# array base class). We can override the isinstance behavior directly, without
# having the complexity of multiple inheritance on those classes, by defining
# the ndarray class to have a metaclass with special __instancecheck__ behavior.
# The tuple below lists the types that the metaclass treats as arrays.
_arraylike_types = (np.ndarray, UnshapedArray, DeviceArray)
106
class _ArrayMeta(type(np.ndarray)):  # type: ignore
  """Metaclass that redirects isinstance checks to `_arraylike_types`.

  An object exposing an ``aval`` attribute is judged by that attribute;
  anything else is judged directly.
  """

  def __instancecheck__(self, instance):
    try:
      abstract_value = instance.aval
      matched = isinstance(abstract_value, _arraylike_types)
    except AttributeError:
      # No `aval` available: fall back to checking the object itself.
      matched = isinstance(instance, _arraylike_types)
    return matched
115
class ndarray(np.ndarray, metaclass=_ArrayMeta):
  """Array type used for isinstance checks and type annotations.

  isinstance behavior is supplied by the ``_ArrayMeta`` metaclass. The class
  cannot be instantiated: ``__init__`` always raises ``TypeError``.
  """
  # Attribute annotations only; no values are assigned here.
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    # `self` is declared explicitly so that `shape` does not silently bind to
    # the instance and every call shape reaches the message below.
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")
125
126
# Direct re-export of the NumPy helper.
iscomplexobj = np.iscomplexobj

# Public names plus private aliases (_shape, _ndim) bound to the same NumPy
# functions.
shape = _shape = np.shape
ndim = _ndim = np.ndim
size = np.size
# Private alias for dtypes.result_type.
_dtype = dtypes.result_type
133
134# At present JAX doesn't have a reason to distinguish between scalars and arrays
135# in its object system. Further, we want JAX scalars to have the same type
136# promotion behaviors as JAX arrays. Rather than introducing a new type of JAX
137# scalar object with JAX promotion behaviors, instead we make the JAX scalar
138# types return JAX arrays when instantiated.
139
140class _ScalarMeta(type):
141  def __hash__(self):
142    return hash(self.dtype.type)
143
144  def __eq__(self, other):
145    return id(self) == id(other) or self.dtype.type == other
146
147  def __ne__(self, other):
148    return not (self == other)
149
150  def __call__(self, x):
151    return array(x, dtype=self.dtype)
152
def _make_scalar_type(np_scalar_type):
  """Build a `_ScalarMeta` class named after `np_scalar_type`.

  The new class derives from ``object`` and stores ``np.dtype(np_scalar_type)``
  in its ``dtype`` attribute.
  """
  type_name = np_scalar_type.__name__
  namespace = {"dtype": np.dtype(np_scalar_type)}
  return _ScalarMeta(type_name, (object,), namespace)
156
# JAX scalar types, one per supported NumPy scalar type (plus bfloat16 from
# jax.dtypes). Each is a class created by _make_scalar_type.
bool_ = _make_scalar_type(np.bool_)
uint8 = _make_scalar_type(np.uint8)
uint16 = _make_scalar_type(np.uint16)
uint32 = _make_scalar_type(np.uint32)
uint64 = _make_scalar_type(np.uint64)
int8 = _make_scalar_type(np.int8)
int16 = _make_scalar_type(np.int16)
int32 = _make_scalar_type(np.int32)
int64 = _make_scalar_type(np.int64)
bfloat16 = _make_scalar_type(dtypes.bfloat16)
float16 = _make_scalar_type(np.float16)
float32 = single = _make_scalar_type(np.float32)
float64 = double = _make_scalar_type(np.float64)
complex64 = csingle = _make_scalar_type(np.complex64)
complex128 = cdouble = _make_scalar_type(np.complex128)

# Default integer/float/complex types: the narrower type is chosen when the
# corresponding default in jax.dtypes is the narrower NumPy type.
int_ = int32 if dtypes.int_ == np.int32 else int64
float_ = float32 if dtypes.float_ == np.float32 else float64
complex_ = complex64 if dtypes.complex_ == np.complex64 else complex128
176
# Abstract scalar type hierarchy, re-exported from NumPy.
number = np.number
inexact = np.inexact
complexfloating = np.complexfloating
floating = np.floating
integer = np.integer
signedinteger = np.signedinteger
unsignedinteger = np.unsignedinteger

flexible = np.flexible
character = np.character
object_ = np.object_

# Type-information helpers come from jax.dtypes rather than NumPy.
iinfo = dtypes.iinfo
finfo = dtypes.finfo

dtype = np.dtype
can_cast = dtypes.can_cast
issubsctype = dtypes.issubsctype
promote_types = dtypes.promote_types

ComplexWarning = np.ComplexWarning

# Printing helpers, re-exported from NumPy.
array_str = np.array_str
array_repr = np.array_repr

# Serialization helpers, re-exported from NumPy.
save = np.save
savez = np.savez
load = np.load
205
206
207### utility functions
208
# Maps NumPy's default scalar types to the JAX default types defined above;
# consulted by _np_array when no dtype is given.
_DEFAULT_TYPEMAP = {
  np.bool_: bool_,
  np.int_: int_,
  np.float_: float_,
  np.complex_: complex_
}

# Signed integer type for each supported bit width; used to reinterpret float
# bit patterns (see signbit, ldexp, frexp).
_INT_DTYPES = {
  16: np.int16,
  32: np.int32,
  64: np.int64,
}
221
def _np_array(obj, dtype=None, **kwargs):
  """Build a NumPy array from `obj`, applying JAX default dtypes.

  Behaves like ``np.array(obj, dtype=dtype, **kwargs)``. When no ``dtype`` is
  passed and ``obj`` has no ``dtype`` attribute, a result whose scalar type is
  a key of ``_DEFAULT_TYPEMAP`` is cast to the mapped type.
  """
  result = np.array(obj, dtype=dtype, **kwargs)
  source_dtype = getattr(obj, 'dtype', None)
  scalar_type = np.dtype(result.dtype).type
  # An explicit dtype, or one carried by the input, is always respected.
  if dtype is not None or source_dtype is not None:
    return result
  if scalar_type not in _DEFAULT_TYPEMAP:
    return result
  return result.astype(_DEFAULT_TYPEMAP[scalar_type])

_np_asarray = partial(_np_array, copy=False)
237
def _promote_shapes(fun_name, *args):
  """Prepend leading size-1 dimensions so all non-scalar args share a rank.

  Arguments are returned untouched when fewer than two are given, or when the
  non-scalar arguments already agree on rank. Otherwise the rank-promotion
  flag is honored and each argument is broadcast to the common rank.
  """
  if len(args) < 2:
    return args
  shapes = [shape(arg) for arg in args]
  distinct_ranks = {len(shp) for shp in shapes if shp}
  if len(distinct_ranks) <= 1:
    return args
  if FLAGS.jax_numpy_rank_promotion != "allow":
    _rank_promotion_warning_or_error(fun_name, shapes)
  result_rank = len(lax.broadcast_shapes(*shapes))
  promoted = []
  for arg, shp in zip(args, shapes):
    leading_ones = (1,) * (result_rank - len(shp))
    promoted.append(broadcast_to(arg, leading_ones + shp))
  return promoted
253
def _rank_promotion_warning_or_error(fun_name, shapes):
  """Warn or raise about implicit rank promotion, per the config flag.

  With the flag set to "warn" a warning is issued; with "raise" a ValueError
  is raised; any other value does nothing.
  """
  mode = FLAGS.jax_numpy_rank_promotion
  if mode == "warn":
    template = ("Following NumPy automatic rank promotion for {} on shapes {}. "
                "Set the jax_numpy_rank_promotion config option to 'allow' to "
                "disable this warning; for more information, see "
                "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    shape_text = ' '.join(str(shp) for shp in shapes)
    warnings.warn(template.format(fun_name, shape_text))
  elif mode == "raise":
    template = ("Operands could not be broadcast together for {} on shapes {} "
                "and with the config option jax_numpy_rank_promotion='raise'. "
                "For more information, see "
                "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
    shape_text = ' '.join(str(shp) for shp in shapes)
    raise ValueError(template.format(fun_name, shape_text))
267
def _promote_dtypes(*args):
  """Convert all arguments to their common result dtype.

  With fewer than two arguments, the arguments are returned unchanged.
  """
  # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
  if len(args) >= 2:
    common_dtype = result_type(*args)
    return [lax.convert_element_type(arg, common_dtype) for arg in args]
  return args
276
def _promote_dtypes_inexact(*args):
  """Convert all arguments to a common inexact dtype.

  The common result dtype is computed first and then passed through
  `_to_inexact_dtype`; a list is always returned.
  """
  common_dtype = result_type(*args)
  inexact_dtype = _to_inexact_dtype(common_dtype)
  return [lax.convert_element_type(arg, inexact_dtype) for arg in args]
283
def _to_inexact_dtype(dtype):
  """Return `dtype` if it is inexact, else its promotion with `float_`."""
  if issubdtype(dtype, inexact):
    return dtype
  return promote_types(dtype, float_)
287
288def _complex_elem_type(dtype):
289  """Returns the float type of the real/imaginary parts of a complex dtype."""
290  return np.abs(np.zeros((), dtype)).dtype
291
def _result_dtype(op, *args):
  """Return the dtype `op` produces for arguments with the given dtypes.

  `op` is run on zero-size NumPy placeholders that have each argument's rank
  and dtype, so no real data is computed.
  """
  placeholders = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args]
  outcome = op(*placeholders)
  return _dtype(outcome)
296
297
def _arraylike(x):
  """Return whether `x` is an ndarray instance or a scalar."""
  if isinstance(x, ndarray):
    return True
  return isscalar(x)
def _check_arraylike(fun_name, *args):
  """Raise TypeError for the first argument that is not array-like.

  "Array-like" is decided by `_arraylike` (ndarray or scalar). The error
  reports the type and position of the offending argument.
  """
  assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
  for pos, arg in enumerate(args):
    if not _arraylike(arg):
      msg = "{} requires ndarray or scalar arguments, got {} at position {}."
      raise TypeError(msg.format(fun_name, type(arg), pos))
307
def _check_no_float0s(fun_name, *args):
  """Raise TypeError if any argument's dtype is `dtypes.float0`."""
  for arg in args:
    if dtypes.dtype(arg) is dtypes.float0:
      raise TypeError(
          f"Called {fun_name} with a float0 array. "
          "float0s do not support any operations by design because they "
          "are not compatible with non-trivial vector spaces. No implicit dtype "
          "conversion is done. You can use np.zeros_like(arr, dtype=np.float) "
          "to cast a float0 array to a regular zeros array. \n"
          "If you didn't expect to get a float0 you might have accidentally "
          "taken a gradient with respect to an integer argument.")
319
def _promote_args(fun_name, *args):
  """Validate arguments, then apply NumPy dtype and shape promotion.

  Arguments must be array-like and must not be float0; dtype promotion runs
  before shape promotion.
  """
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  dtype_promoted = _promote_dtypes(*args)
  return _promote_shapes(fun_name, *dtype_promoted)
325
def _promote_args_inexact(fun_name, *args):
  """Validate arguments, then promote them to a common inexact dtype and rank.

  Same checks as `_promote_args`, but dtype promotion goes through
  `_promote_dtypes_inexact`.
  """
  _check_arraylike(fun_name, *args)
  _check_no_float0s(fun_name, *args)
  dtype_promoted = _promote_dtypes_inexact(*args)
  return _promote_shapes(fun_name, *dtype_promoted)
333
def _constant_like(x, const):
  """Return `const` as a NumPy array with the result dtype of `x`."""
  target_dtype = _dtype(x)
  return np.array(const, dtype=target_dtype)
336
337### implementations of numpy functions in terms of lax
338
@_wraps(np.fmin)
def fmin(x1, x2):
  # Pick x1 where it is smaller or where x2 is NaN; otherwise pick x2.
  take_first = (x1 < x2) | isnan(x2)
  return where(take_first, x1, x2)
342
@_wraps(np.fmax)
def fmax(x1, x2):
  # Pick x1 where it is larger or where x2 is NaN; otherwise pick x2.
  take_first = (x1 > x2) | isnan(x2)
  return where(take_first, x1, x2)
346
@_wraps(np.issubdtype)
def issubdtype(arg1, arg2):
  # Delegates to the jax.dtypes implementation.
  is_sub = dtypes.issubdtype(arg1, arg2)
  return is_sub
350
@_wraps(np.isscalar)
def isscalar(element):
  # Try the jax.dtypes notion of a Python scalar first, then NumPy's check.
  python_scalar = dtypes.is_python_scalar(element)
  if python_scalar:
    return python_scalar
  return np.isscalar(element)

iterable = np.iterable
356
@_wraps(np.result_type)
def result_type(*args):
  # Delegates to the jax.dtypes implementation.
  promoted = dtypes.result_type(*args)
  return promoted
360
def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
  """Wrap a unary `lax_fn` as the implementation of `numpy_fn`.

  The argument is promoted (to an inexact dtype when `promote_to_inexact` is
  set) before `lax_fn` is applied. With `lax_doc`, everything after the first
  paragraph of `lax_fn`'s docstring is handed to `_wraps` as `lax_description`.
  """
  promote = _promote_args_inexact if promote_to_inexact else _promote_args
  fn = lambda x: lax_fn(*promote(numpy_fn.__name__, x))
  if not lax_doc:
    return _wraps(numpy_fn)(fn)
  trailing_paragraphs = lax_fn.__doc__.split('\n\n')[1:]
  doc = _dedent('\n\n'.join(trailing_paragraphs)).strip()
  return _wraps(numpy_fn, lax_description=doc)(fn)
371
def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
  """Wrap a binary `lax_fn` as the implementation of `numpy_fn`.

  Both arguments are promoted together (to an inexact dtype when
  `promote_to_inexact` is set) before `lax_fn` is applied. With `lax_doc`,
  everything after the first paragraph of `lax_fn`'s docstring is handed to
  `_wraps` as `lax_description`.
  """
  promote = _promote_args_inexact if promote_to_inexact else _promote_args
  fn = lambda x1, x2: lax_fn(*promote(numpy_fn.__name__, x1, x2))
  if not lax_doc:
    return _wraps(numpy_fn)(fn)
  trailing_paragraphs = lax_fn.__doc__.split('\n\n')[1:]
  doc = _dedent('\n\n'.join(trailing_paragraphs)).strip()
  return _wraps(numpy_fn, lax_description=doc)(fn)
382
def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False):
  """Wrap a binary op that uses a different lax primitive for booleans.

  After promotion, `bool_lax_fn` is applied when the operands are boolean and
  `lax_fn` otherwise. With `lax_doc`, everything after the first paragraph of
  `lax_fn`'s docstring is handed to `_wraps` as `lax_description`.
  """
  def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2)
  # No unconditional return here: it would make the lax_doc branch unreachable.
  if lax_doc:
    doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip()
    return _wraps(numpy_fn, lax_description=doc)(fn)
  else:
    return _wraps(numpy_fn)(fn)
393
# Unary functions that map directly onto one lax primitive. A third positional
# argument of True is `promote_to_inexact`.
fabs = _one_to_one_unop(np.fabs, lax.abs, True)
bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not)
invert = _one_to_one_unop(np.invert, lax.bitwise_not)
negative = _one_to_one_unop(np.negative, lax.neg)
positive = _one_to_one_unop(np.positive, lambda x: x)

floor = _one_to_one_unop(np.floor, lax.floor, True)
ceil = _one_to_one_unop(np.ceil, lax.ceil, True)
exp = _one_to_one_unop(np.exp, lax.exp, True)
log = _one_to_one_unop(np.log, lax.log, True)
expm1 = _one_to_one_unop(np.expm1, lax.expm1, True)
log1p = _one_to_one_unop(np.log1p, lax.log1p, True)
sin = _one_to_one_unop(np.sin, lax.sin, True)
cos = _one_to_one_unop(np.cos, lax.cos, True)
tan = _one_to_one_unop(np.tan, lax.tan, True)
arcsin = _one_to_one_unop(np.arcsin, lax.asin, True)
arccos = _one_to_one_unop(np.arccos, lax.acos, True)
arctan = _one_to_one_unop(np.arctan, lax.atan, True)
sinh = _one_to_one_unop(np.sinh, lax.sinh, True)
cosh = _one_to_one_unop(np.cosh, lax.cosh, True)
tanh = _one_to_one_unop(np.tanh, lax.tanh, True)
# arcsinh is bound exactly once (a redundant second definition was removed).
arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True)
arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True)
sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True)
419
420
# Binary functions that map directly onto one lax primitive. add and multiply
# switch to bitwise_or / bitwise_and for boolean operands. A third positional
# argument of True to _one_to_one_binop is `promote_to_inexact`; nextafter
# additionally passes lax_doc=True.
add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or)
bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and)
bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or)
bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor)
left_shift = _one_to_one_binop(np.left_shift, lax.shift_left)
equal = _one_to_one_binop(np.equal, lax.eq)
multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and)
not_equal = _one_to_one_binop(np.not_equal, lax.ne)
subtract = _one_to_one_binop(np.subtract, lax.sub)
arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True)
minimum = _one_to_one_binop(np.minimum, lax.min)
maximum = _one_to_one_binop(np.maximum, lax.max)
float_power = _one_to_one_binop(np.float_power, lax.pow, True)
nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True)
435
@_wraps(np.arccosh)
def arccosh(x):
  # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different
  # convention than np.arccosh.
  operand, = _promote_args_inexact("arccosh", x)
  out = lax.acosh(operand)
  if not issubdtype(out.dtype, np.complexfloating):
    return out
  # For complex results, negate values whose real part is negative.
  return where(real(out) < 0, lax.neg(out), out)
444
def _comparison_op(numpy_fn, lax_fn):
  """Wrap an ordering comparison `lax_fn` as the implementation of `numpy_fn`.

  Complex operands are compared lexicographically on (real, imag).
  """
  def fn(x1, x2):
    x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
    if not issubdtype(_dtype(x1), complexfloating):
      return lax_fn(x1, x2)
    # Comparison on complex types are defined as a lexicographic ordering on
    # the (real, imag) pair.
    real1 = lax.real(x1)
    real2 = lax.real(x2)
    reals_equal = lax.eq(real1, real2)
    by_imag = lax_fn(lax.imag(x1), lax.imag(x2))
    by_real = lax_fn(real1, real2)
    return lax.select(reals_equal, by_imag, by_real)
  return _wraps(numpy_fn)(fn)
457
# Ordering comparisons, each built from the matching lax primitive.
greater_equal = _comparison_op(np.greater_equal, lax.ge)
greater = _comparison_op(np.greater, lax.gt)
less_equal = _comparison_op(np.less_equal, lax.le)
less = _comparison_op(np.less, lax.lt)
462
463
def _logical_op(np_op, bitwise_op):
  """Build a logical operator from a bitwise one.

  Non-boolean arguments are first turned into booleans by comparing them
  against zero; the promoted booleans are then passed to `bitwise_op`.
  """
  @_wraps(np_op, update_doc=False)
  def op(*args):
    def to_bool(x):
      if issubdtype(_dtype(x), bool_):
        return x
      scalar_zero = lax.full_like(x, shape=(), fill_value=0)
      return lax.ne(x, scalar_zero)
    truth_values = [to_bool(x) for x in args]
    return bitwise_op(*_promote_args(np_op.__name__, *truth_values))
  return op
472
# Logical operators, each built from the matching lax bitwise primitive.
logical_and = _logical_op(np.logical_and, lax.bitwise_and)
logical_not = _logical_op(np.logical_not, lax.bitwise_not)
logical_or = _logical_op(np.logical_or, lax.bitwise_or)
logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor)
477
478
@_wraps(np.right_shift)
def right_shift(x1, x2):
  x1, x2 = _promote_args(np.right_shift.__name__, x1, x2)
  # Unsigned operands use a logical shift; everything else an arithmetic one.
  if np.issubdtype(x1.dtype, np.unsignedinteger):
    shift = lax.shift_right_logical
  else:
    shift = lax.shift_right_arithmetic
  return shift(x1, x2)
485
486
@_wraps(np.absolute)
def absolute(x):
  _check_arraylike('absolute', x)
  dt = _dtype(x)
  # Booleans and unsigned integers are returned as-is.
  if dt == bool_ or issubdtype(dt, unsignedinteger):
    return x
  return lax.abs(x)
abs = _wraps(np.abs)(absolute)
493
494
@_wraps(np.rint)
def rint(x):
  _check_arraylike('rint', x)
  dt = _dtype(x)
  if issubdtype(dt, integer):
    # Integers are only converted to the default float type.
    result = lax.convert_element_type(x, float_)
  elif issubdtype(dt, complexfloating):
    # Round real and imaginary parts independently.
    real_part = rint(lax.real(x))
    imag_part = rint(lax.imag(x))
    result = lax.complex(real_part, imag_part)
  else:
    result = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
  return result
504
505
@_wraps(np.sign)
def sign(x):
  _check_arraylike('sign', x)
  if not issubdtype(_dtype(x), complexfloating):
    return lax.sign(x)
  # Complex input: take the sign of the real part, or of the imaginary part
  # where the real part is zero; the imaginary part of the result is zero.
  re = lax.real(x)
  deciding_part = where(re != 0, re, lax.imag(x))
  return lax.complex(lax.sign(deciding_part), _constant_like(re, 0))
515
516
@_wraps(np.copysign)
def copysign(x1, x2):
  x1, x2 = _promote_args_inexact("copysign", x1, x2)
  if issubdtype(_dtype(x1), complexfloating):
    raise TypeError("copysign does not support complex-valued inputs")
  # Choose -|x1| where x2 has its sign bit set, +|x1| elsewhere.
  sign_is_set = signbit(x2)
  minus_magnitude = -lax.abs(x1)
  plus_magnitude = lax.abs(x1)
  return where(sign_is_set, minus_magnitude, plus_magnitude)
523
524
@_wraps(np.true_divide)
def true_divide(x1, x2):
  # Operands are promoted to an inexact dtype before dividing.
  numerator, denominator = _promote_args_inexact("true_divide", x1, x2)
  return lax.div(numerator, denominator)

divide = true_divide
531
@_wraps(np.floor_divide)
def floor_divide(x1, x2):
  x1, x2 = _promote_args("floor_divide", x1, x2)
  dt = _dtype(x1)

  if issubdtype(dt, integer):
    # lax.div result is lowered by one where the operands differ in sign and
    # the remainder is non-zero.
    truncated = lax.div(x1, x2)
    signs_differ = lax.sign(x1) != lax.sign(x2)
    has_remainder = lax.rem(x1, x2) != 0
    needs_adjust = logical_and(signs_differ, has_remainder)
    # TODO(mattjj): investigate why subtracting a scalar was causing promotion
    one = np.array(1, _dtype(truncated))
    return where(needs_adjust, truncated - one, truncated)

  if issubdtype(dt, complexfloating):
    re1, im1 = lax.real(x1), lax.imag(x1)
    re2, im2 = lax.real(x2), lax.imag(x2)
    # Scale by whichever component of x2 has the larger magnitude.
    real_dominates = lax.ge(lax.abs(re2), lax.abs(im2))
    rat1 = where(real_dominates, lax._const(im2, 1), lax.div(re2, im2))
    rat2 = where(real_dominates, lax.div(im2, re2), lax._const(im2, 1))
    numer = lax.add(lax.mul(re1, rat1), lax.mul(im1, rat2))
    denom = lax.add(lax.mul(re2, rat1), lax.mul(im2, rat2))
    out = lax.floor(lax.div(numer, denom))
    return lax.convert_element_type(out, dt)

  return _float_divmod(x1, x2)[0]
554
555
@_wraps(np.divmod)
def divmod(x1, x2):
  x1, x2 = _promote_args("divmod", x1, x2)
  # Non-integer operands take the shared float path.
  if not issubdtype(_dtype(x1), integer):
    return _float_divmod(x1, x2)
  return floor_divide(x1, x2), remainder(x1, x2)
563
564
def _float_divmod(x1, x2):
  """Return the (quotient, remainder) pair for floor division of x1 by x2.

  Starts from lax.rem and shifts both results where the remainder is non-zero
  and its sign differs from that of x2.
  """
  # see float_divmod in floatobject.c of CPython
  trunc_mod = lax.rem(x1, x2)
  trunc_div = lax.div(lax.sub(x1, trunc_mod), x2)

  nonzero_mod = trunc_mod != 0
  sign_mismatch = lax.sign(x2) != lax.sign(trunc_mod)
  adjust = lax.bitwise_and(nonzero_mod, sign_mismatch)
  floor_mod = lax.select(adjust, trunc_mod + x2, trunc_mod)
  floor_div = lax.select(adjust, trunc_div - _constant_like(trunc_div, 1),
                         trunc_div)

  return lax.round(floor_div), floor_mod
575
576
@_wraps(np.power)
def power(x1, x2):
  # Special case for concrete integer scalars: use binary exponentiation.
  # Using lax.pow may be imprecise for floating-point values; the goal of this
  # code path is to make sure we end up with a precise output for the common
  # pattern ``x ** 2`` or similar.
  try:
    x2 = core.concrete_or_error(operator.index, x2)
  except (core.ConcretizationTypeError, TypeError):
    # Not a concrete integer exponent: continue with the general path.
    pass
  else:
    return lax.integer_pow(x1, x2)

  base, exponent = _promote_args("power", x1, x2)
  dt = _dtype(base)
  if not issubdtype(dt, integer):
    return lax.pow(base, exponent)

  # Integer power => use binary exponentiation.

  # TODO(phawkins): add integer pow support to XLA.
  bits = 6  # Anything more would overflow for any x1 > 1
  one = _constant_like(exponent, 1)
  result = ones(shape(base), dtype=dt)
  for _ in range(bits):
    # Multiply in the current square where the lowest exponent bit is set.
    result = where(lax.bitwise_and(exponent, one),
                   lax.mul(result, base), result)
    base = lax.mul(base, base)
    exponent = lax.shift_right_logical(exponent, one)
  return result
606
607
@custom_jvp
@_wraps(np.logaddexp)
def logaddexp(x1, x2):
  x1, x2 = _promote_shapes("logaddexp", *_promote_dtypes_inexact(x1, x2))
  larger = lax.max(x1, x2)
  gap = lax.sub(x1, x2)
  gap_is_nan = isnan(gap)
  plain_sum = lax.add(x1, x2)  # NaNs or infinities of the same sign.
  # max + log1p(exp(-|gap|)) is used wherever the gap is not NaN.
  stable = lax.add(larger, lax.log1p(lax.exp(-lax.abs(gap))))
  return lax.select(gap_is_nan, plain_sum, stable)
617
@logaddexp.defjvp
def _logaddexp_jvp(primals, tangents):
  """JVP rule for logaddexp: returns (primal output, tangent output).

  Each tangent is weighted by exp(x_i - out), with positive infinities
  replaced by zero before subtracting.
  """
  (x1, x2), (t1, t2) = primals, tangents
  x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2)
  primal_out = logaddexp(x1, x2)
  term1 = t1 * exp(_replace_inf(x1) - _replace_inf(primal_out))
  term2 = t2 * exp(_replace_inf(x2) - _replace_inf(primal_out))
  tangent_out = term1 + term2
  return primal_out, tangent_out
627
def _replace_inf(x):
  """Return `x` with every positive-infinity entry replaced by zero."""
  is_pos_inf = isposinf(x)
  zeros = zeros_like(x)
  return lax.select(is_pos_inf, zeros, x)
630
631
@custom_jvp
@_wraps(np.logaddexp2)
def logaddexp2(x1, x2):
  x1, x2 = _promote_shapes("logaddexp2", *_promote_dtypes_inexact(x1, x2))
  larger = lax.max(x1, x2)
  gap = lax.sub(x1, x2)
  gap_is_nan = isnan(gap)
  plain_sum = lax.add(x1, x2)  # NaNs or infinities of the same sign.
  # max + log1p(2**-|gap|) / log(2) is used wherever the gap is not NaN.
  correction = lax.div(lax.log1p(exp2(-lax.abs(gap))),
                       _constant_like(x1, np.log(2)))
  stable = lax.add(larger, correction)
  return lax.select(gap_is_nan, plain_sum, stable)
@logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
  """JVP rule for logaddexp2: returns (primal output, tangent output).

  Each tangent is weighted by 2 ** (x_i - out), with positive infinities
  replaced by zero before subtracting.
  """
  (x1, x2), (t1, t2) = primals, tangents
  x1, x2, t1, t2 = broadcast_arrays(x1, x2, t1, t2)
  primal_out = logaddexp2(x1, x2)
  term1 = t1 * 2 ** (_replace_inf(x1) - _replace_inf(primal_out))
  term2 = t2 * 2 ** (_replace_inf(x2) - _replace_inf(primal_out))
  tangent_out = term1 + term2
  return primal_out, tangent_out
651
652
@_wraps(np.log2)
def log2(x):
  # Computed as log(x) / log(2) in the promoted inexact dtype.
  operand, = _promote_dtypes_inexact(x)
  log_x = lax.log(operand)
  log_base = lax.log(_constant_like(operand, 2))
  return lax.div(log_x, log_base)
657
658
@_wraps(np.log10)
def log10(x):
  # Computed as log(x) / log(10) in the promoted inexact dtype.
  operand, = _promote_dtypes_inexact(x)
  log_x = lax.log(operand)
  log_base = lax.log(_constant_like(operand, 10))
  return lax.div(log_x, log_base)
663
664
@_wraps(np.exp2)
def exp2(x):
  # Computed as exp(log(2) * x) in the promoted inexact dtype.
  operand, = _promote_dtypes_inexact(x)
  log_two = lax.log(_constant_like(operand, 2))
  return lax.exp(lax.mul(log_two, operand))
669
@_wraps(np.signbit)
def signbit(x):
  x, = _promote_shapes("signbit", x)
  dt = _dtype(x)
  if issubdtype(dt, integer):
    return lax.lt(x, _constant_like(x, 0))
  if issubdtype(dt, bool_):
    return full_like(x, False, dtype=bool_)
  if not issubdtype(dt, floating):
    raise ValueError(
        "jax.numpy.signbit is not well defined for %s" % dt)

  # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to
  # F32.
  if dt == bfloat16:
    dt = float32
    x = lax.convert_element_type(x, float32)

  info = finfo(dt)
  width = info.bits
  if width not in _INT_DTYPES:
    raise NotImplementedError(
        "jax.numpy.signbit only supports 16, 32, and 64-bit types.")
  # Reinterpret the float as a same-width signed integer and shift the
  # exponent and mantissa fields away.
  as_int = lax.bitcast_convert_type(x, _INT_DTYPES[width])
  field_width = info.nexp + info.nmant
  return lax.convert_element_type(as_int >> field_width, np.bool_)
695
696
@_wraps(np.trapz)
def trapz(y, x=None, dx=1.0, axis: int = -1):
  _check_arraylike('trapz', y)
  # Move the integration axis last so the slices below act on it.
  y = moveaxis(y, axis, -1)
  if x is not None:
    # Sample spacing comes from x; a 1-D x is differenced directly.
    dx = diff(x) if ndim(x) == 1 else moveaxis(diff(x, axis=axis), axis, -1)
  pair_sums = dx * (y[..., 1:] + y[..., :-1])
  return 0.5 * pair_sums.sum(-1)
707
708
@_wraps(np.trunc)
def trunc(x):
  _check_arraylike('trunc', x)
  # Negative values use ceil, the rest use floor.
  is_negative = lax.lt(x, lax._const(x, 0))
  return where(is_negative, ceil(x), floor(x))
713
714
def _conv(x, y, mode, op, precision):
  """Shared 1-D implementation behind `convolve` and `correlate`.

  `op` is the caller's name ('convolve' or 'correlate'); `mode` is one of
  'valid', 'same' or 'full'; `precision` is forwarded to
  `lax.conv_general_dilated`. Raises NotImplementedError for complex inputs
  and ValueError for non-1-D inputs, empty inputs or an unknown mode.
  """
  if _any(issubdtype(_dtype(operand), complexfloating) for operand in (x, y)):
    raise NotImplementedError(f"{op}() does not support complex inputs")
  if ndim(x) != 1 or ndim(y) != 1:
    raise ValueError(f"{op}() only support 1-dimensional inputs.")
  x, y = _promote_dtypes_inexact(x, y)
  if len(x) == 0 or len(y) == 0:
    raise ValueError(f"{op}: inputs cannot be empty, got shapes {x.shape} and {y.shape}.")

  # The longer input acts as the signal; correlate reverses its output when
  # the operands were swapped.
  out_order = slice(None)
  if len(x) < len(y):
    x, y = y, x
    if op == "correlate":
      out_order = slice(None, None, -1)
  if op == 'convolve':
    y = y[::-1]

  kernel_len = y.shape[0]
  if mode == 'valid':
    padding = [(0, 0)]
  elif mode == 'same':
    left = kernel_len // 2
    padding = [(left, kernel_len - left - 1)]
  elif mode == 'full':
    padding = [(kernel_len - 1, kernel_len - 1)]
  else:
    raise ValueError("mode must be one of ['full', 'same', 'valid']")

  # Add two leading size-1 axes to each operand for the general convolution.
  signal = x[None, None, :]
  kernel = y[None, None, :]
  result = lax.conv_general_dilated(signal, kernel, (1,), padding,
                                    precision=precision)
  return result[0, 0, out_order]
744
745
@_wraps(np.convolve, lax_description=_PRECISION_DOC)
def convolve(a, v, mode='full', *, precision=None):
  _check_arraylike("convolve", a, v)
  # Shared 1-D kernel; the op name selects convolution behavior.
  return _conv(a, v, mode=mode, op='convolve', precision=precision)
750
751
@_wraps(np.correlate, lax_description=_PRECISION_DOC)
def correlate(a, v, mode='valid', *, precision=None):
  _check_arraylike("correlate", a, v)
  # Shared 1-D kernel; the op name selects correlation behavior.
  return _conv(a, v, mode=mode, op='correlate', precision=precision)
756
757
def _normalize_float(x):
  """Return (integer bit pattern, exponent offset) for float array `x`.

  Entries whose magnitude is below ``finfo.tiny`` are multiplied by
  ``2 ** nmant`` before the bitcast, and the matching offset ``-nmant`` is
  reported for them; all other entries get offset 0.
  """
  info = finfo(_dtype(x))
  below_tiny = lax.abs(x) < info.tiny
  scaled = where(below_tiny, x * lax._const(x, 1 << info.nmant), x)
  offset = where(below_tiny, lax._const(np.int32, -info.nmant),
                 lax._const(np.int32, 0))
  same_width_int = _INT_DTYPES[info.bits]
  return lax.bitcast_convert_type(scaled, same_width_int), offset
765
766
@_wraps(np.ldexp)
@jit
def ldexp(x1, x2):
  # The result dtype is what np.ldexp yields for these argument dtypes, passed
  # through dtypes.canonicalize_dtype.
  dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2))
  x1, x2 = _promote_shapes("ldexp", x1, x2)
  x1 = lax.convert_element_type(x1, dtype)

  info = finfo(dtype)
  # mask: nexp low bits set; bias: that value shifted right by one.
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  int_type = _INT_DTYPES[info.bits]

  # x is the integer bit pattern of (possibly rescaled) x1 and e the exponent
  # offset reported by _normalize_float for that rescaling.
  x, e = _normalize_float(x1)
  # Add the exponent held in x's exponent field (minus bias) to x2.
  x2 += e + ((x >> info.nmant) & mask) - bias

  # find underflow/overflow before denormalization
  underflow_cond = x2 < -(bias + info.nmant)
  overflow_cond = x2 > bias

  # m is a multiplier applied to the reconstructed float at the end.
  m = ones_like(x, dtype=dtype)

  # denormals
  # Raise the exponent by nmant and compensate by scaling m down.
  cond = x2 < -bias + 1
  x2 = where(cond, x2 + info.nmant, x2)
  m = where(cond, m / (1 << info.nmant), m)

  x2 = lax.convert_element_type(x2, np.int32)
  # Clear the exponent field, then write the new biased exponent into it.
  x &= ~(mask << info.nmant)
  x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant)

  # Reinterpret the bits as a float again and apply the multiplier.
  x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype)

  # underflow
  x = where(underflow_cond, zeros_like(x, dtype=dtype), x)
  # overflow
  x = where(overflow_cond, lax.sign(x1) * full_like(x, np.inf), x)
  # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0
  return where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x)
806
807
@_wraps(np.frexp)
@jit
def frexp(x):
  x = asarray(x)
  # Complex input is rejected; other non-floating input is converted to the
  # default float type.
  if issubdtype(x.dtype, complexfloating):
    raise TypeError("frexp does not support complex-valued inputs")
  elif not issubdtype(x.dtype, floating):
    x = lax.convert_element_type(x, float_)

  dtype = _dtype(x)
  info = finfo(dtype)
  # mask: nexp low bits set; bias: that value shifted right by one.
  mask = (1 << info.nexp) - 1
  bias = ((1 << info.nexp) - 1) >> 1

  # x1 is the integer bit pattern of (possibly rescaled) x; x2 starts as the
  # exponent offset reported by _normalize_float.
  x1, x2 = _normalize_float(x)
  # Add the exponent held in the exponent field (minus bias, plus one).
  x2 += ((x1 >> info.nmant) & mask) - bias + 1
  # Overwrite the exponent field with bias - 1, then view the bits as a float.
  x1 &= ~(mask << info.nmant)
  x1 |= (bias - 1) << info.nmant
  x1 = lax.bitcast_convert_type(x1, dtype)

  # For inf, nan and zero the input is returned with an exponent of zero.
  cond = isinf(x) | isnan(x) | (x == 0)
  x2 = where(cond, zeros_like(x2), x2)
  return where(cond, x, x1), lax.convert_element_type(x2, int32)
831
832
@_wraps(np.remainder)
def remainder(x1, x2):
  x1, x2 = _promote_args("remainder", x1, x2)
  zero = _constant_like(x1, 0)
  trunc_mod = lax.rem(x1, x2)
  mod_is_nonzero = lax.ne(trunc_mod, zero)
  # Add x2 where the lax.rem result is non-zero and its sign differs from
  # that of x2.
  signs_differ = lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero))
  needs_shift = lax.bitwise_and(signs_differ, mod_is_nonzero)
  return lax.select(needs_shift, lax.add(trunc_mod, x2), trunc_mod)
mod = _wraps(np.mod)(remainder)
843
844
@_wraps(np.fmod)
def fmod(x1, x2):
  _check_arraylike("fmod", x1, x2)
  divisor = x2
  if issubdtype(_dtype(x1, x2), integer):
    # For integer operands, zero divisors are replaced by one.
    divisor = where(x2 == 0, 1, x2)
  dividend, divisor = _promote_args("fmod", x1, divisor)
  return lax.rem(dividend, divisor)
851
852
@_wraps(np.cbrt)
def cbrt(x):
  _check_arraylike("cbrt", x)
  operand, = _promote_dtypes_inexact(x)
  # sign(x) * |x| ** (1/3)
  sign_part = lax.sign(operand)
  root = power(lax.abs(operand), _constant_like(operand, 1. / 3.))
  return sign_part * root
858
859
@_wraps(np.square)
def square(x):
  _check_arraylike("square", x)
  # Squaring is an integer power with a fixed exponent of two.
  exponent = 2
  return lax.integer_pow(x, exponent)
864
865
@_wraps(np.deg2rad)
def deg2rad(x):
  _check_arraylike("deg2rad", x)
  operand, = _promote_dtypes_inexact(x)
  # Scale by pi / 180 in the operand's dtype.
  factor = lax._const(operand, pi / 180)
  return lax.mul(operand, factor)
871
872
@_wraps(np.rad2deg)
def rad2deg(x):
  _check_arraylike("rad2deg", x)
  operand, = _promote_dtypes_inexact(x)
  # Scale by 180 / pi in the operand's dtype.
  factor = lax._const(operand, 180 / pi)
  return lax.mul(operand, factor)


degrees = rad2deg
radians = deg2rad
882
883
@_wraps(np.histogram_bin_edges)
def histogram_bin_edges(a, bins=10, range=None, weights=None):
  # NOTE: `weights` is accepted but not referenced in this implementation.
  if isinstance(bins, str):
    raise NotImplementedError("string values for `bins` not implemented.")
  a = ravel(a)
  explicit_edges = asarray(bins)
  if explicit_edges.ndim == 1:
    # A one-dimensional `bins` is returned as the edges themselves.
    return explicit_edges
  if range is None:
    range = (a.min(), a.max())
  assert len(range) == 2
  range = asarray(range)
  # When both ends of the range coincide, widen it by 0.5 on each side.
  is_degenerate = ptp(range) == 0
  lower = where(is_degenerate, range[0] - 0.5, range[0])
  upper = where(is_degenerate, range[1] + 0.5, range[1])
  dtype = _dtype(a)
  if issubdtype(dtype, integer):
    dtype = promote_types(dtype, float32)
  return linspace(lower, upper, bins + 1, dtype=dtype)
902
903
@_wraps(np.histogram)
def histogram(a, bins=10, range=None, weights=None, density=None):
  weighted = weights is not None
  if weighted and a.shape != weights.shape:
    raise ValueError("weights should have the same shape as a.")
  a = ravel(a)
  # Unweighted histograms count every sample with weight one.
  weights = ravel(weights) if weighted else ones_like(a)
  bin_edges = histogram_bin_edges(a, bins, range, weights)
  num_edges = len(bin_edges)
  bin_idx = searchsorted(bin_edges, a, side='right')
  # Samples equal to the last edge are assigned to the last bin.
  bin_idx = where(a == bin_edges[-1], num_edges - 1, bin_idx)
  # Slot 0 collects samples below the first edge and is dropped.
  counts = bincount(bin_idx, weights, length=num_edges)[1:]
  if density:
    counts = counts / diff(bin_edges) / counts.sum()
  return counts, bin_edges
921
@_wraps(np.histogram2d)
def histogram2d(x, y, bins=10, range=None, weights=None, density=None):
  # Objects without a length (e.g. a plain int) count as a single bin spec.
  try:
    num_specs = len(bins)
  except TypeError:
    num_specs = 1

  if num_specs not in (1, 2):
    # Any other length is treated as one edge array shared by both axes.
    shared_edges = asarray(bins)
    bins = [shared_edges, shared_edges]

  # Stack x and y into an (N, 2) sample and delegate to histogramdd.
  sample = transpose(asarray([x, y]))
  hist, edges = histogramdd(sample, bins, range, weights, density)
  x_edges, y_edges = edges[0], edges[1]
  return hist, x_edges, y_edges
937
@_wraps(np.histogramdd)
def histogramdd(sample, bins=10, range=None, weights=None, density=None):
  # Multidimensional histogram of an (N, D) `sample`.
  #
  # Args:
  #   sample: array of shape (N, D); one row per sample, one column per
  #     dimension.
  #   bins: either a sequence with one bin specification per dimension, or a
  #     single value reused for every dimension.
  #   range: None, or a sequence with one entry per dimension; each entry is
  #     None or a (lower, upper) pair forwarded to histogram_bin_edges for that
  #     dimension.
  #   weights: None, or an array of shape (N,).
  #   density: if truthy, divide the counts by the bin volumes and by the
  #     total count.
  #
  # Returns:
  #   (hist, bin_edges_by_dim): the D-dimensional histogram and the list of
  #   per-dimension edge arrays.
  #
  # Raises:
  #   ValueError: if `weights`, `bins` or `range` do not match the sample shape.
  _check_arraylike("histogramdd", sample)
  N, D = shape(sample)

  if weights is not None and weights.shape != (N,):
    raise ValueError("should have one weight for each sample.")

  if range is not None and len(range) != D:
    raise ValueError("range argument must have one entry per dimension")

  try:
    num_bins = len(bins)
    if num_bins != D:
      raise ValueError("should be a bin for each dimension.")
  except TypeError:
    # when bin_size is integer, the same bin is used for each dimension
    bins = D * [bins]

  bin_idx_by_dim = D*[None]
  nbins = np.empty(D, int)
  bin_edges_by_dim = D*[None]
  dedges = D*[None]

  for i in builtins.range(D):
    # Each dimension gets its own range entry; previously the whole `range`
    # sequence was forwarded unchanged to every dimension.
    range_i = None if range is None else range[i]
    bin_edges = histogram_bin_edges(sample[:, i], bins[i], range_i, weights)
    bin_idx = searchsorted(bin_edges, sample[:, i], side='right')
    # Samples equal to the last edge belong to the last bin.
    bin_idx = where(sample[:, i] == bin_edges[-1], bin_idx - 1, bin_idx)
    bin_idx_by_dim[i] = bin_idx
    # Two extra slots per dimension hold the out-of-range samples.
    nbins[i] = len(bin_edges) + 1
    bin_edges_by_dim[i] = bin_edges
    dedges[i] = diff(bin_edges_by_dim[i])

  xy = ravel_multi_index(bin_idx_by_dim, nbins, mode='clip')
  hist = bincount(xy, weights, length=nbins.prod())
  hist = reshape(hist, nbins)
  # Drop the outlier slots on both ends of every dimension.
  interior = D*(slice(1, -1),)
  hist = hist[interior]

  if density:
    total = sum(hist)
    for i in builtins.range(D):
      # Reshape the widths so they broadcast along dimension i only.
      width_shape = np.ones(D, int)
      width_shape[i] = nbins[i] - 2
      hist = hist / reshape(dedges[i], width_shape)

    hist /= total

  return hist, bin_edges_by_dim
984
@_wraps(np.heaviside)
def heaviside(x1, x2):
  _check_arraylike("heaviside", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  zero = lax._const(x1, 0)
  is_negative = lax.lt(x1, zero)
  is_positive = lax.gt(x1, zero)
  one = lax._const(x1, 1)
  # Negative -> 0, positive -> 1, everything else -> x2.
  non_negative_value = where(is_positive, one, x2)
  return where(is_negative, zero, non_negative_value)
992
993
@_wraps(np.hypot)
def hypot(x1, x2):
  _check_arraylike("hypot", x1, x2)
  x1, x2 = _promote_dtypes_inexact(x1, x2)
  x1 = lax.abs(x1)
  x2 = lax.abs(x2)
  larger, smaller = maximum(x1, x2), minimum(x1, x2)
  # Computes larger * sqrt(1 + (smaller / larger)**2); where `larger` is zero
  # the divisor is replaced by one and the result is `larger` itself.
  larger_is_zero = larger == 0
  safe_divisor = lax.select(larger_is_zero, ones_like(larger), larger)
  ratio = lax.div(smaller, safe_divisor)
  scaled = larger * lax.sqrt(1 + lax.square(ratio))
  return lax.select(larger_is_zero, larger, scaled)
1002
1003
@_wraps(np.reciprocal)
def reciprocal(x):
  _check_arraylike("reciprocal", x)
  x, = _promote_dtypes_inexact(x)
  # The reciprocal is expressed as the integer power -1.
  exponent = -1
  return lax.integer_pow(x, exponent)
1009
1010
@_wraps(np.sinc, update_doc=False)
def sinc(x):
  _check_arraylike("sinc", x)
  x, = _promote_dtypes_inexact(x)
  at_zero = lax.eq(x, lax._const(x, 0))
  pi_x = lax.mul(lax._const(x, pi), x)
  # `safe_pi_x` is pi * x with the entries selected by `at_zero` replaced by a
  # zero constant; the quotient computed from it is only used away from zero.
  safe_pi_x = where(at_zero, lax._const(x, 0), pi_x)
  value_at_zero = _sinc_maclaurin(0, pi_x)
  value_elsewhere = lax.div(lax.sin(safe_pi_x), safe_pi_x)
  return where(at_zero, value_at_zero, value_elsewhere)
1020
@partial(custom_jvp, nondiff_argnums=(0,))
def _sinc_maclaurin(k, x):
  # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we
  # compute the monomial term in the jvp rule)
  # Odd orders give 0; even orders give (-1)**(k/2) / (k + 1). The result is
  # broadcast to the shape and dtype of x.
  is_odd = k % 2
  coefficient = 0 if is_odd else (-1) ** (k // 2) / (k + 1)
  return lax.full_like(x, coefficient)

@_sinc_maclaurin.defjvp
def _sinc_maclaurin_jvp(k, primals, tangents):
  # The tangent is the next-order coefficient times the input tangent.
  (x,), (t,) = primals, tangents
  primal_out = _sinc_maclaurin(k, x)
  tangent_out = _sinc_maclaurin(k + 1, x) * t
  return primal_out, tangent_out
1034
1035
@_wraps(np.transpose)
def transpose(a, axes=None):
  _check_arraylike("transpose", a)
  if axes is None:
    # Default permutation: all axes in reverse order.
    axes = np.arange(ndim(a))[::-1]
  return lax.transpose(a, axes)
1041
1042
@_wraps(np.rot90)
def rot90(m, k=1, axes=(0, 1)):
  _check_arraylike("rot90", m)
  ax1, ax2 = axes
  ax1 = _canonicalize_axis(ax1, ndim(m))
  ax2 = _canonicalize_axis(ax2, ndim(m))
  if ax1 == ax2:
    raise ValueError("Axes must be different")  # same as numpy error
  quarter_turns = k % 4
  if quarter_turns == 0:
    return m
  if quarter_turns == 2:
    # Half turn: reverse both axes.
    return flip(flip(m, ax1), ax2)
  # Remaining cases combine a swap of the two axes with one flip along ax2.
  swap_perm = list(range(m.ndim))
  swap_perm[ax1], swap_perm[ax2] = swap_perm[ax2], swap_perm[ax1]
  if quarter_turns == 1:
    return transpose(flip(m, ax2), swap_perm)
  return flip(transpose(m, swap_perm), ax2)
1063
1064
@_wraps(np.flip)
def flip(m, axis: Optional[Union[int, Tuple[int, ...]]] = None):
  # Reverse the order of elements of `m`.
  #
  # Args:
  #   m: array-like input.
  #   axis: None to reverse every axis, a single int, or a tuple/list of ints
  #     naming the axes to reverse. Negative axes are canonicalized against
  #     ndim(m).
  #
  # Returns:
  #   The result of lax.rev over the selected axes.
  _check_arraylike("flip", m)
  if axis is None:
    return lax.rev(m, list(range(len(shape(m)))))
  # The annotation (and NumPy) allow several axes at once; each one has to be
  # canonicalized individually rather than passing the tuple through.
  axes = axis if isinstance(axis, (tuple, list)) else (axis,)
  return lax.rev(m, [_canonicalize_axis(ax, ndim(m)) for ax in axes])
1071
1072
@_wraps(np.fliplr)
def fliplr(m):
  # Reverses axis 1 by delegating to flip.
  return flip(m, axis=1)
1076
1077
@_wraps(np.flipud)
def flipud(m):
  # Reverses axis 0 by delegating to flip.
  return flip(m, axis=0)
1081
1082
@_wraps(np.conjugate)
def conjugate(x):
  _check_arraylike("conjugate", x)
  # Inputs that are not complex objects are returned unchanged.
  if iscomplexobj(x):
    return lax.conj(x)
  return x
conj = conjugate
1088
1089
@_wraps(np.imag)
def imag(val):
  _check_arraylike("imag", val)
  # Inputs that are not complex objects have an all-zero imaginary part.
  if iscomplexobj(val):
    return lax.imag(val)
  return zeros_like(val)
1094
1095
@_wraps(np.real)
def real(val):
  _check_arraylike("real", val)
  # Inputs that are not complex objects are their own real part.
  if iscomplexobj(val):
    return lax.real(val)
  return val
1100
1101
@_wraps(np.iscomplex)
def iscomplex(x):
  # True wherever the imaginary part differs from zero.
  imag_part = imag(x)
  zero = lax._const(imag_part, 0)
  return lax.ne(imag_part, zero)
1106
@_wraps(np.isreal)
def isreal(x):
  # True wherever the imaginary part equals zero.
  imag_part = imag(x)
  zero = lax._const(imag_part, 0)
  return lax.eq(imag_part, zero)
1111
@_wraps(np.angle)
def angle(z):
  real_part, imag_part = real(z), imag(z)
  part_dtype = _dtype(real_part)
  # Convert both parts to the canonical default float when the real part is
  # not inexact, or when z is a zero-dimensional floating-point value.
  needs_cast = not issubdtype(part_dtype, inexact) or (
      issubdtype(_dtype(z), floating) and ndim(z) == 0)
  if needs_cast:
    target = dtypes.canonicalize_dtype(float_)
    real_part = lax.convert_element_type(real_part, target)
    imag_part = lax.convert_element_type(imag_part, target)
  return lax.atan2(imag_part, real_part)
1123
1124
@_wraps(np.diff)
def diff(a, n=1, axis: int = -1, prepend=None, append=None):
  _check_arraylike("diff", a)
  if n == 0:
    return a
  if n < 0:
    raise ValueError(f"order must be non-negative but got {n}")
  if ndim(a) == 0:
    raise ValueError(f"diff requires input that is at least one dimensional; got {a}")

  nd = a.ndim

  def _as_boundary(value):
    # Scalars are broadcast to the shape of `a` with extent 1 along `axis`;
    # anything else is used as given.
    if isscalar(value):
      slab_shape = list(a.shape)
      slab_shape[axis] = 1
      return broadcast_to(value, tuple(slab_shape))
    return value

  # Assemble [prepend, a, append] (skipping the absent pieces).
  pieces = []
  if prepend is not None:
    _check_arraylike("diff", prepend)
    pieces.append(_as_boundary(prepend))

  pieces.append(a)

  if append is not None:
    _check_arraylike("diff", append)
    pieces.append(_as_boundary(append))

  if len(pieces) > 1:
    a = concatenate(pieces, axis)

  # Index tuples selecting "all but the first" and "all but the last" entry
  # along `axis`.
  leading = [slice(None)] * nd
  trailing = [slice(None)] * nd
  leading[axis] = slice(1, None)
  trailing[axis] = slice(None, -1)
  leading_index = tuple(leading)
  trailing_index = tuple(trailing)

  # Boolean inputs are differenced with not_equal, everything else with
  # subtract.
  op = not_equal if a.dtype == np.bool_ else subtract
  for _ in range(n):
    a = op(a[leading_index], a[trailing_index])

  return a
1171
_EDIFF1D_DOC = """\
Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not
issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary``
loses precision.
"""

@_wraps(np.ediff1d, lax_description=_EDIFF1D_DOC)
def ediff1d(ary, to_end=None, to_begin=None):
  ary = ravel(asarray(ary))
  # Differences between consecutive elements of the flattened input.
  deltas = lax.sub(ary[1:], ary[:-1])
  if to_begin is not None:
    # Cast to the dtype of `ary`, flatten, and put in front.
    head = ravel(asarray(to_begin, dtype=ary.dtype))
    deltas = concatenate((head, deltas))
  if to_end is not None:
    # Cast to the dtype of `ary`, flatten, and put at the end.
    tail = ravel(asarray(to_end, dtype=ary.dtype))
    deltas = concatenate((deltas, tail))
  return deltas
1187
1188
@partial(jit, static_argnums=2)
def _gradient(a, varargs, axis):
  """Jit-compiled implementation backing :func:`gradient`.

  Args:
    a: input array.
    varargs: sequence of spacing arguments. Empty (or None) means unit
      spacing, a single entry is reused for every axis, and otherwise there
      must be exactly one entry per differentiated axis.
    axis: None (differentiate along every axis), an int, or a tuple/list of
      ints. Passed as a static argument to jit.

  Returns:
    A list with one gradient array per differentiated axis, the bare array
    when exactly one axis is differentiated, or an empty list when `axis` is
    an empty tuple/list.

  Raises:
    ValueError: if `axis` has an unsupported type, or a differentiated axis
      has fewer than 2 elements.
    TypeError: if the number of spacing arguments is neither 0, 1, nor the
      number of differentiated axes.
    NotImplementedError: if the first spacing argument is not a scalar.
  """
  def gradient_along_axis(a, h, axis):
    # One-sided differences at both ends, central differences in between.
    sliced = partial(lax.slice_in_dim, a, axis=axis)
    a_grad = concatenate((
      (sliced(1, 2) - sliced(0, 1)),  # upper edge
      (sliced(2, None) - sliced(None, -2)) * 0.5,  # inner
      (sliced(-1, None) - sliced(-2, -1)),  # lower edge
    ), axis)
    return a_grad / h

  if axis is None:
    axis = range(a.ndim)
  else:
    if isinstance(axis, int):
      axis = (axis,)
    if not isinstance(axis, tuple) and not isinstance(axis, list):
      raise ValueError("Give `axis` either as int or iterable")
    elif len(axis) == 0:
      return []
    axis = [_canonicalize_axis(i, a.ndim) for i in axis]

  if _min([s for i, s in enumerate(a.shape) if i in axis]) < 2:
    raise ValueError("Shape of array too small to calculate "
                     "a numerical gradient, "
                     "at least 2 elements are required.")
  len_axes = len(axis)
  # Check for None before taking the length, so a missing spacing tuple is
  # treated like an empty one instead of failing inside len().
  n = 0 if varargs is None else len(varargs)
  if n == 0:
    # no spacing
    dx = [1.0] * len_axes
  elif n == 1:
    # single value for all axes
    dx = varargs * len_axes
  elif n == len_axes:
    dx = varargs
  else:
    # The exception used to be constructed without `raise`, so execution fell
    # through to the use of an unbound `dx` below.
    raise TypeError("Invalid number of spacing arguments %d" % n)

  if ndim(dx[0]) != 0:
    raise NotImplementedError("Non-constant spacing not implemented")

  # TODO: use jax.lax loop tools if possible
  a_grad = [gradient_along_axis(a, h, ax) for ax, h in zip(axis, dx)]

  if len(axis) == 1:
    a_grad = a_grad[0]

  return a_grad
1238
1239
@_wraps(np.gradient)
def gradient(f, *varargs, axis: Optional[Union[int, Tuple[int, ...]]] = None,
             edge_order=None):
  # Only the default edge handling is available; any explicit `edge_order`
  # is rejected before delegating to the jitted helper.
  edge_order_given = edge_order is not None
  if edge_order_given:
    raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.")
  spacing = varargs
  return _gradient(f, spacing, axis)
1246
1247
@_wraps(np.isrealobj)
def isrealobj(x):
  # Defined as the negation of iscomplexobj.
  is_complex = iscomplexobj(x)
  return not is_complex
1251
1252
@_wraps(np.reshape)
def reshape(a, newshape, order="C"):
  try:
    # Prefer the object's own reshape method (forward to method for ndarrays).
    reshaped = a.reshape(newshape, order=order)
  except AttributeError:
    # Fall back to the module-level implementation on AttributeError.
    reshaped = _reshape(a, newshape, order=order)
  return reshaped
1259
def _compute_newshape(a, newshape):
  """Return `newshape` as a list, with a single -1 entry resolved from a.size."""
  # other errors, like having more than one -1, are caught downstream
  def _concrete_dim(dim):
    # Poly entries pass through; everything else must be a concrete index.
    if type(dim) is Poly:
      return dim
    return core.concrete_or_error(
      operator.index, dim, "The error arose in jax.numpy.reshape.")

  # NOTE(review): the bare `except` is kept on purpose so that any failure of
  # iter() is treated as "not iterable", exactly as before.
  try:
    iter(newshape)
    is_iterable = True
  except:
    is_iterable = False

  raw_dims = newshape if is_iterable else [newshape]
  newshape = [_concrete_dim(dim) for dim in raw_dims]
  if not np.any(np.equal(newshape, -1)):
    return newshape
  # The product below includes the -1 itself, hence the negated size.
  fix = -a.size // (newshape if type(newshape) is Poly else _prod(newshape))
  return [d if d != -1 else fix for d in newshape]
1275
def _reshape(a, *args, order="C"):
  # The target shape may be given as one sequence or as separate positional
  # dimensions.
  requested = args[0] if len(args) == 1 else args
  newshape = _compute_newshape(a, requested)
  if order == "C":
    return lax.reshape(a, newshape, None)
  if order == "F":
    # Reshape to the reversed shape with the input dimensions reversed, then
    # transpose the result.
    reversed_dims = np.arange(ndim(a))[::-1]
    return lax.reshape(a, newshape[::-1], reversed_dims).T
  if order == "A":
    raise NotImplementedError("np.reshape order=A is not implemented.")
  raise ValueError("Unexpected value for 'order' argument: {}.".format(order))
1287
@_wraps(np.ravel)
def ravel(a, order="C"):
  if order == "K":
    raise NotImplementedError("Ravel not implemented for order='K'.")
  # Flattening is a reshape to a single axis holding every element.
  flat_shape = (size(a),)
  return reshape(a, flat_shape, order)
1293
1294
@_wraps(np.ravel_multi_index)
def ravel_multi_index(multi_index, dims, mode='raise', order='C'):
  assert len(multi_index) == len(dims), f"len(multi_index)={len(multi_index)} != len(dims)={len(dims)}"
  dims = tuple(core.concrete_or_error(int, d, "in `dims` argument of ravel_multi_index().") for d in dims)
  _check_arraylike("ravel_multi_index", *multi_index)
  raising = mode == 'raise'
  for index in multi_index:
    if raising:
      # Bounds checking needs concrete values, so traced indices are rejected.
      core.concrete_or_error(array, index,
        "The error occurred because ravel_multi_index was jit-compiled"
        " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
    if not issubdtype(_dtype(index), integer):
      raise TypeError("only int indices permitted")

  # Out-of-bounds handling, selected by `mode`.
  if mode == "raise":
    if _any(any((i < 0) | (i >= d)) for i, d in zip(multi_index, dims)):
      raise ValueError("invalid entry in coordinates array")
  elif mode == "clip":
    multi_index = [clip(i, 0, d - 1) for i, d in zip(multi_index, dims)]
  elif mode == "wrap":
    multi_index = [i % d for i, d in zip(multi_index, dims)]
  else:
    raise ValueError(f"invalid mode={mode!r}. Expected 'raise', 'wrap', or 'clip'")

  # Per-dimension strides of the flattened index for the requested order.
  if order == "F":
    strides = np.cumprod((1,) + dims[:-1])
  elif order == "C":
    strides = np.cumprod((1,) + dims[1:][::-1])[::-1]
  else:
    raise ValueError(f"invalid order={order!r}. Expected 'C' or 'F'")

  # Flat index = sum over dimensions of index * stride.
  flat_index = 0
  for index, stride in zip(multi_index, strides):
    flat_index = flat_index + index * stride
  return flat_index
1328
1329
_UNRAVEL_INDEX_DOC = """\
Unlike numpy's implementation of unravel_index, negative indices are accepted
and out-of-bounds indices are clipped.
"""

@_wraps(np.unravel_index, lax_description=_UNRAVEL_INDEX_DOC)
def unravel_index(indices, shape):
  indices = asarray(indices)
  # Append a trailing 1 to the shape, then take reversed cumulative products:
  # entry 0 is the total size.
  padded_sizes = pad(shape, (0, 1), constant_values=1)
  strides = cumprod(padded_sizes[::-1])[::-1]
  total_size = strides[0]
  # Clip so raveling and unraveling an oob index will not change the behavior
  clipped = clip(indices, -total_size, total_size - 1)
  # Add enough trailing dims to avoid conflict with flat_index
  strides = strides.reshape([-1] + [1] * indices.ndim)
  remainders = clipped % strides[:-1]
  coords = remainders // strides[1:]
  return tuple(coords)
1347
1348
@_wraps(np.squeeze)
def squeeze(a, axis: Optional[Union[int, Tuple[int, ...]]] = None):
  _check_arraylike("squeeze", a)
  if axis is None:
    # Default: every axis whose extent is exactly 1.
    axis = tuple(i for i, extent in enumerate(shape(a)) if extent == 1)
  elif not isinstance(axis, tuple):
    # Anything that is not already a tuple is wrapped into a 1-tuple.
    axis = (axis,)
  return lax.squeeze(a, axis)
1358
1359
@_wraps(np.expand_dims)
def expand_dims(a, axis: Union[int, Tuple[int, ...]]):
  _check_arraylike("expand_dims", a)
  # Anything that is not already a tuple is wrapped into a 1-tuple.
  axes = axis if isinstance(axis, tuple) else (axis,)
  return lax.expand_dims(a, axes)
1366
1367
@_wraps(np.swapaxes)
def swapaxes(a, axis1: int, axis2: int):
  _check_arraylike("swapaxes", a)
  # Identity permutation with the two requested entries exchanged.
  order = np.arange(ndim(a))
  first, second = order[axis1], order[axis2]
  order[axis1] = second
  order[axis2] = first
  return lax.transpose(a, order)
1374
1375
@_wraps(np.moveaxis)
def moveaxis(a, source: Union[int, Sequence[int]],
             destination: Union[int, Sequence[int]]):
  _check_arraylike("moveaxis", a)

  def _as_axis_tuple(axes) -> Tuple[int, ...]:
    # A single index becomes a 1-tuple; otherwise treat it as a sequence.
    try:
      return (operator.index(axes),)
    except TypeError:
      return tuple(cast(Sequence[int], axes))

  source_axes: Tuple[int, ...] = _as_axis_tuple(source)
  destination_axes: Tuple[int, ...] = _as_axis_tuple(destination)
  source_axes = tuple(_canonicalize_axis(i, ndim(a)) for i in source_axes)
  destination_axes = tuple(_canonicalize_axis(i, ndim(a))
                           for i in destination_axes)
  if len(source_axes) != len(destination_axes):
    raise ValueError("Inconsistent number of elements: {} vs {}"
                     .format(len(source_axes), len(destination_axes)))
  # Start from the axes that stay put, then insert each moved axis at its
  # destination, processing destinations in increasing order.
  perm = [i for i in range(ndim(a)) if i not in source_axes]
  for dest, src in sorted(zip(destination_axes, source_axes)):
    perm.insert(dest, src)
  return lax.transpose(a, perm)
1400
1401
@_wraps(np.isclose)
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  a, b = _promote_args("isclose", asarray(a), asarray(b))
  dtype = _dtype(a)
  if not issubdtype(dtype, inexact):
    # Exact dtypes: closeness is plain equality.
    return lax.eq(a, b)

  if issubdtype(dtype, complexfloating):
    # Tolerances are real-valued, so use the complex element type.
    dtype = _complex_elem_type(dtype)
  rtol = lax.convert_element_type(rtol, dtype)
  atol = lax.convert_element_type(atol, dtype)
  # |a - b| <= atol + rtol * |b|
  distance = lax.abs(lax.sub(a, b))
  tolerance = lax.add(atol, lax.mul(rtol, lax.abs(b)))
  out = lax.le(distance, tolerance)

  # Infinite values: False if either side is infinite, unless both sides are
  # infinite and compare equal.
  a_inf, b_inf = isinf(a), isinf(b)
  either_inf = logical_or(a_inf, b_inf)
  both_inf = logical_and(a_inf, b_inf)
  out = logical_and(out, logical_not(either_inf))
  identical = lax.eq(a, b)
  out = logical_or(out, logical_and(both_inf, identical))

  # NaN values: False if either side is NaN, unless `equal_nan` is set and
  # both sides are NaN.
  a_nan, b_nan = isnan(a), isnan(b)
  either_nan = logical_or(a_nan, b_nan)
  out = logical_and(out, logical_not(either_nan))
  if equal_nan:
    out = logical_or(out, logical_and(a_nan, b_nan))
  return _maybe_numpy_1_13_isclose_behavior(a, out)
1438
# (major, minor) of the installed NumPy, as a tuple of ints.
numpy_version = tuple(int(part) for part in np.version.version.split('.')[:2])
if numpy_version >= (1, 14):
  def _maybe_numpy_1_13_isclose_behavior(a, out):
    # NumPy 1.14 and later: the result is returned untouched.
    return out
else:
  # see discussion at https://github.com/numpy/numpy/pull/9720
  def _maybe_numpy_1_13_isclose_behavior(a, out):
    # Older NumPy: single-element results for complex input get shape (1,).
    single_complex = size(out) == 1 and issubdtype(_dtype(a), complexfloating)
    if single_complex:
      return lax.reshape(out, (1,))
    return out
1450
@_wraps(np.interp)
def interp(x, xp, fp, left=None, right=None, period=None):
  if shape(xp) != shape(fp) or ndim(xp) != 1:
    raise ValueError("xp and fp must be one-dimensional arrays of equal size")
  x, xp, fp = map(asarray, _promote_dtypes_inexact(x, xp, fp))
  periodic = period is not None
  if periodic:
    if period == 0:
      raise ValueError(f"period must be a non-zero value; got {period}")
    period = abs(period)
    # Reduce everything modulo the period, sort the sample points, and pad
    # one wrapped-around sample on each side.
    x = x % period
    xp = xp % period
    xp, fp = lax.sort_key_val(xp, fp)
    xp = concatenate([xp[-1:] - period, xp, xp[:1] + period])
    fp = concatenate([fp[-1:], fp, fp[:1]])

  # `hi` is the index of the right neighbour, clipped to a valid interval.
  hi = clip(searchsorted(xp, x, side='right'), 1, len(xp) - 1)
  lo = hi - 1
  df = fp[hi] - fp[lo]
  dx = xp[hi] - xp[lo]
  delta = x - xp[lo]
  # Linear interpolation; zero-width intervals take the right-hand value.
  f = where((dx == 0), fp[hi], fp[lo] + (delta / dx) * df)

  if not periodic:
    # Outside the sampled range use `left`/`right`, defaulting to the end
    # values of fp.
    left_fill = fp[0] if left is None else left
    f = where(x < xp[0], left_fill, f)
    right_fill = fp[-1] if right is None else right
    f = where(x > xp[-1], right_fill, f)
  return f
1476
1477
@_wraps(np.in1d, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def in1d(ar1, ar2, assume_unique=False, invert=False):
  ar1 = ravel(ar1)
  ar2 = ravel(ar2)
  # Note: an algorithm based on searchsorted has better scaling, but in practice
  # is very slow on accelerators because it relies on lax control flow. If XLA
  # ever supports binary search natively, we should switch to this:
  #   ar2 = jnp.sort(ar2)
  #   ind = jnp.searchsorted(ar2, ar1)
  #   if invert:
  #     return ar1 != ar2[ind]
  #   else:
  #     return ar1 == ar2[ind]
  # Compare every element of ar1 against every element of ar2 and reduce over
  # the ar2 axis.
  column = ar1[:, None]
  if invert:
    return (column != ar2).all(-1)
  return (column == ar2).any(-1)
1497
@_wraps(np.setdiff1d, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def setdiff1d(ar1, ar2, assume_unique=False):
  # Both inputs must be concrete arrays.
  context = "The error arose in setdiff1d()"
  ar1 = core.concrete_or_error(asarray, ar1, context)
  ar2 = core.concrete_or_error(asarray, ar2, context)

  ar1 = unique(ar1)
  ar2 = unique(ar2)

  # Keep the unique elements of ar1 that do not occur in ar2.
  keep = in1d(ar1, ar2, invert=True)
  return ar1[keep]
1510
@partial(jit, static_argnums=2)
def _intersect1d_sorted_mask(ar1, ar2, return_indices=False):
  """
    Jit-able helper for intersect1d.

    Sorts the concatenation of ar1 and ar2 and returns it together with a mask
    marking positions whose successor holds an equal value; with
    return_indices the sort permutation is returned as a third element.
    """
  ar = concatenate((ar1, ar2))
  if not return_indices:
    aux = sort(ar)
    return aux, aux[1:] == aux[:-1]

  # Sort positions alongside the values to recover where each came from.
  iota = lax.broadcasted_iota(np.int64, shape(ar), dimension=0)
  aux, indices = lax.sort_key_val(ar, iota)
  return aux, aux[1:] == aux[:-1], indices
1528
1529
@_wraps(np.intersect1d)
def intersect1d(ar1, ar2, assume_unique=False, return_indices=False):
  ar1 = core.concrete_or_error(asarray, ar1, "The error arose in intersect1d()")
  ar2 = core.concrete_or_error(asarray, ar2, "The error arose in intersect1d()")

  if assume_unique:
    ar1 = ravel(ar1)
    ar2 = ravel(ar2)
  elif return_indices:
    # Remember where each unique value first came from.
    ar1, ind1 = unique(ar1, return_index=True)
    ar2, ind2 = unique(ar2, return_index=True)
  else:
    ar1 = unique(ar1)
    ar2 = unique(ar2)

  if not return_indices:
    aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices)
    return aux[:-1][mask]

  aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices)
  int1d = aux[:-1][mask]

  # Positions in the concatenated array: the first of each equal pair is
  # mapped into ar1, the second (offset by ar1.size) into ar2.
  ar1_indices = aux_sort_indices[:-1][mask]
  ar2_indices = aux_sort_indices[1:][mask] - ar1.size
  if not assume_unique:
    # Translate positions in the uniquified arrays back to the inputs.
    ar1_indices = ind1[ar1_indices]
    ar2_indices = ind2[ar2_indices]

  return int1d, ar1_indices, ar2_indices
1563
1564
@_wraps(np.isin, lax_description="""
In the JAX version, the `assume_unique` argument is not referenced.
""")
def isin(element, test_elements, assume_unique=False, invert=False):
  # Compute the flat membership mask, then restore the shape of `element`.
  flat_mask = in1d(element, test_elements, assume_unique=assume_unique, invert=invert)
  return flat_mask.reshape(shape(element))
1571
1572
# The `jit` on `where` exists to avoid materializing constants in cases like
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@jit
def _where(condition, x=None, y=None):
  missing_operand = x is None or y is None
  if missing_operand:
    raise ValueError("Either both or neither of the x and y arguments should "
                     "be provided to jax.numpy.where, got {} and {}."
                     .format(x, y))
  if not issubdtype(_dtype(condition), bool_):
    # Non-boolean conditions are interpreted as "differs from zero".
    condition = lax.ne(condition, zeros_like(condition))
  x, y = _promote_dtypes(x, y)
  condition, x, y = broadcast_arrays(condition, x, y)
  if np.size(x):
    return lax.select(condition, x, y)
  # Size-zero result: return the broadcast x without calling lax.select.
  return x
1587
1588
_WHERE_DOC = """\
At present, JAX does not support JIT-compilation of the single-argument form
of :py:func:`jax.numpy.where` because its output shape is data-dependent. The
three-argument form does not have a data-dependent shape and can be JIT-compiled
successfully.
"""

@_wraps(np.where, update_doc=False, lax_description=_WHERE_DOC)
def where(condition, x=None, y=None):
  # With neither x nor y this is the single-argument form, handled by nonzero;
  # every other combination goes to the jitted three-argument helper.
  single_argument_form = x is None and y is None
  if single_argument_form:
    return nonzero(asarray(condition))
  return _where(condition, x, y)
1602
1603
@_wraps(np.select)
def select(condlist, choicelist, default=0):
  num_conds, num_choices = len(condlist), len(choicelist)
  if num_conds != num_choices:
    msg = "condlist must have length equal to choicelist ({} vs {})"
    raise ValueError(msg.format(num_conds, num_choices))
  if num_conds == 0:
    raise ValueError("condlist must be non-empty")
  # Promote the default together with the choices to a common dtype.
  promoted = _promote_dtypes(default, *choicelist)
  output = promoted[0]
  choicelist = promoted[1:]
  # Apply the conditions last-to-first so that earlier conditions win.
  for cond, choice in zip(condlist[::-1], choicelist[::-1]):
    output = where(cond, choice, output)
  return output
1617
1618
@_wraps(np.bincount, lax_description="""\
Jax adds the optional `length` parameter which specifies the output length, and
defaults to ``x.max() + 1``. It must be specified for bincount to be compilable.
Values larger than the specified length will be discarded.

Additionally, while ``np.bincount`` raises an error if the input array contains
negative values, ``jax.numpy.bincount`` treats negative values as zero.
""")
def bincount(x, weights=None, minlength=0, *, length=None):
  _check_arraylike("bincount", x)
  if not issubdtype(_dtype(x), integer):
    msg = f"x argument to bincount must have an integer type; got {x.dtype}"
    raise TypeError(msg)
  if length is None:
    # Without a static length the output size depends on the data, so x has
    # to be concrete.
    x = core.concrete_or_error(asarray, x,
      "The error occured because of argument 'x' of jnp.bincount. "
      "To avoid this error, pass a static `length` argument.")
    length = max(x) + 1
  length = _max(length, minlength)
  if ndim(x) != 1:
    raise ValueError("only 1-dimensional input supported.")
  if weights is None:
    # Unweighted: every occurrence adds an int32 one.
    weights = array(1, dtype=int32)
  elif shape(x) != shape(weights):
    raise ValueError("shape of weights must match shape of x.")
  accumulator = zeros((length,), _dtype(weights))
  # Negative entries are clipped to bin zero.
  positions = ops.index[clip(x, 0)]
  return ops.index_add(accumulator, positions, weights)
1646
1647
def broadcast_arrays(*args):
  """Like Numpy's broadcast_arrays but doesn't return views."""
  shapes = [shape(arg) for arg in args]
  all_same_shape = len(set(shapes)) == 1
  if all_same_shape:
    # Nothing to broadcast: ndarrays and scalars pass through untouched,
    # everything else is converted with array().
    passthrough = []
    for arg in args:
      if isinstance(arg, ndarray) or isscalar(arg):
        passthrough.append(arg)
      else:
        passthrough.append(array(arg))
    return passthrough
  result_shape = lax.broadcast_shapes(*shapes)
  return [broadcast_to(arg, result_shape) for arg in args]
1656
1657
@_wraps(np.broadcast_to, lax_description="""\
The JAX version does not necessarily return a view of the input.
""")
def broadcast_to(arr, shape):
  arr = arr if isinstance(arr, ndarray) else array(arr)
  shape = canonicalize_shape(shape)  # check that shape is concrete
  arr_shape = _shape(arr)
  if arr_shape == shape:
    return arr

  # Number of leading dimensions the target has beyond the input.
  nlead = len(shape) - len(arr_shape)
  # Each input dimension must match the target or be 1.
  compatible = np.equal(arr_shape, shape[nlead:]) | np.equal(arr_shape, 1)
  if nlead < 0 or not np.all(compatible):
    msg = "Incompatible shapes for broadcasting: {} and requested shape {}"
    raise ValueError(msg.format(arr_shape, shape))
  # Input dimensions whose extent differs from the target are squeezed away
  # and re-created by broadcast_in_dim along with the new leading dims.
  stretched, = np.where(np.not_equal(shape[nlead:], arr_shape))
  new_dims = tuple(range(nlead)) + tuple(nlead + stretched)
  kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims))
  squeezed = squeeze(arr, tuple(stretched))
  return lax.broadcast_in_dim(squeezed, shape, kept_dims)
1677
1678
def _split(op, ary, indices_or_sections, axis=0):
  # Shared implementation of split/array_split; `op` names the public
  # function, both for error messages and to allow uneven division.
  axis = core.concrete_or_error(int, axis, f"in jax.numpy.{op} argument `axis`")
  axis_len = ary.shape[axis]
  if isinstance(indices_or_sections, (tuple, list) + _arraylike_types):
    # Explicit split points: bracket them with 0 and the axis length.
    cut_points = np.array(
        [core.concrete_or_error(np.int64, i_s, f"in jax.numpy.{op} argument 1")
         for i_s in indices_or_sections], np.int64)
    boundaries = np.concatenate([[np.int64(0)], cut_points,
                                 [np.int64(axis_len)]])
  else:
    num_sections = core.concrete_or_error(np.int64, indices_or_sections,
                                          f"in jax.numpy.{op} argument 1")
    part_size, r = _divmod(axis_len, num_sections)
    if r == 0:
      boundaries = np.arange(num_sections + 1, dtype=np.int64) * part_size
    elif op == "array_split":
      # The first r sections get one extra element each.
      larger = np.arange(r + 1, dtype=np.int64) * (part_size + 1)
      smaller = (np.arange(num_sections - r, dtype=np.int64) * part_size
                 + ((r + 1) * (part_size + 1) - 1))
      boundaries = np.concatenate([larger, smaller])
    else:
      raise ValueError("array split does not result in an equal division")
  starts, ends = [0] * ndim(ary), shape(ary)
  pieces = []
  for start, end in zip(boundaries[:-1], boundaries[1:]):
    # Full extent on every axis except `axis`, which is cut to [start, end).
    piece_starts = subvals(starts, [(axis, start)])
    piece_ends = subvals(ends, [(axis, end)])
    pieces.append(lax.slice(ary, piece_starts, piece_ends))
  return pieces
1706
@_wraps(np.split)
def split(ary, indices_or_sections, axis: int = 0):
  # Delegates to the shared helper under the name "split".
  op_name = "split"
  return _split(op_name, ary, indices_or_sections, axis=axis)
1710
def _split_on_axis(np_fun, axis):
  """Build a `split` variant bound to a fixed `axis`, wrapped as `np_fun`."""
  @_wraps(np_fun, update_doc=False)
  def f(ary, indices_or_sections):
    pieces = split(ary, indices_or_sections, axis=axis)
    return pieces
  return f

# Fixed-axis variants: axis 0, 1 and 2 respectively.
vsplit = _split_on_axis(np.vsplit, axis=0)
hsplit = _split_on_axis(np.hsplit, axis=1)
dsplit = _split_on_axis(np.dsplit, axis=2)
1720
@_wraps(np.array_split)
def array_split(ary, indices_or_sections, axis: int = 0):
  # Delegates to the shared helper under the name "array_split".
  op_name = "array_split"
  return _split(op_name, ary, indices_or_sections, axis=axis)
1724
@_wraps(np.clip)
def clip(a, a_min=None, a_max=None, out=None):
  _check_arraylike("clip", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.clip is not supported.")
  no_lower, no_upper = a_min is None, a_max is None
  if no_lower and no_upper:
    raise ValueError("At most one of a_min and a_max may be None")
  clipped = a
  if not no_lower:
    # Lower bound is applied first.
    clipped = maximum(a_min, clipped)
  if not no_upper:
    clipped = minimum(a_max, clipped)
  return clipped
1737
@_wraps(np.round, update_doc=False)
def round(a, decimals=0, out=None):
  _check_arraylike("round", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.round is not supported.")
  dtype = _dtype(a)
  if issubdtype(dtype, integer):
    if decimals < 0:
      raise NotImplementedError(
        "integer np.round not implemented for decimals < 0")
    return a  # no-op on integer types

  is_half_precision = dtype == np.float16
  to_even = lax.RoundingMethod.TO_NEAREST_EVEN

  def _round_float(x):
    if decimals == 0:
      return lax.round(x, to_even)

    # TODO(phawkins): the strategy of rescaling the value isn't necessarily a
    # good one since we may be left with an incorrectly rounded value at the
    # end due to precision problems. As a workaround for float16, convert to
    # float32,
    if is_half_precision:
      x = lax.convert_element_type(x, np.float32)
    # Scale up by 10**decimals, round to nearest even, scale back down.
    factor = _constant_like(x, 10 ** decimals)
    rounded = lax.div(lax.round(lax.mul(x, factor), to_even), factor)
    if is_half_precision:
      return lax.convert_element_type(rounded, dtype)
    return rounded

  if issubdtype(dtype, complexfloating):
    # Complex values are rounded component-wise.
    return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a)))
  return _round_float(a)
around = round
1769
1770
@_wraps(np.fix)
def fix(x, out=None):
  _check_arraylike("fix", x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.fix is not supported.")
  # Round toward zero: floor where x >= 0, ceil elsewhere.
  is_nonnegative = lax.ge(x, lax._const(x, 0))
  return where(is_nonnegative, floor(x), ceil(x))
1778
1779
@_wraps(np.modf)
def modf(x, out=None):
  _check_arraylike("modf", x)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.modf is not supported.")
  # The integral part comes from `fix`; the fractional part is the remainder.
  integral = fix(x)
  fractional = x - integral
  return fractional, integral
1787
1788
@_wraps(np.isfinite)
def isfinite(x):
  _check_arraylike("isfinite", x)
  dtype = _dtype(x)
  if issubdtype(dtype, floating):
    return lax.is_finite(x)
  if not issubdtype(dtype, complexfloating):
    # Neither floating nor complex: every element is reported as finite.
    return full_like(x, True, dtype=bool_)
  # A complex value is finite only when both of its components are.
  real_ok = lax.is_finite(real(x))
  imag_ok = lax.is_finite(imag(x))
  return lax.bitwise_and(real_ok, imag_ok)
1799
@_wraps(np.isinf)
def isinf(x):
  _check_arraylike("isinf", x)
  dtype = _dtype(x)

  def _magnitude_is_inf(v):
    # Comparing |v| with +inf catches both signs of infinity.
    return lax.eq(lax.abs(v), _constant_like(v, inf))

  if issubdtype(dtype, floating):
    return _magnitude_is_inf(x)
  if issubdtype(dtype, complexfloating):
    # A complex value is infinite when either component is.
    re = lax.real(x)
    im = lax.imag(x)
    return lax.bitwise_or(_magnitude_is_inf(re), _magnitude_is_inf(im))
  return full_like(x, False, dtype=bool_)
1813
def _isposneginf(infinity, x, out):
  """Elementwise test of ``x`` against one signed ``infinity``.

  Shared by ``isposinf`` (``infinity=inf``) and ``isneginf``
  (``infinity=-inf``). Complex inputs raise ``ValueError``; inputs that are
  neither floating nor complex give an all-False boolean result.
  """
  if out is not None:
    raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.")
  dtype = _dtype(x)
  if issubdtype(dtype, complexfloating):
    raise ValueError("isposinf/isneginf are not well defined for complex types")
  if not issubdtype(dtype, floating):
    return full_like(x, False, dtype=bool_)
  return lax.eq(x, _constant_like(x, infinity))

isposinf = _wraps(np.isposinf)(lambda x, out=None: _isposneginf(inf, x, out))

isneginf = _wraps(np.isneginf)(lambda x, out=None: _isposneginf(-inf, x, out))
1828
@_wraps(np.isnan)
def isnan(x):
  _check_arraylike("isnan", x)
  # Implemented as a self-inequality test: elements where x != x are flagged.
  differs_from_self = lax.ne(x, x)
  return differs_from_self
1833
@_wraps(np.nan_to_num)
def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
  # `copy` exists only for signature compatibility with NumPy; it is ignored.
  del copy
  _check_arraylike("nan_to_num", x)
  dtype = _dtype(x)
  if not issubdtype(dtype, inexact):
    # Integer and boolean inputs cannot hold NaN or infinities, and the
    # `finfo` lookup below is only meaningful for inexact dtypes, so return
    # the input unchanged (as an array) instead of failing.
    return asarray(x)
  if issubdtype(dtype, complexfloating):
    # Complex inputs are cleaned component-wise with the same replacements.
    return lax.complex(
      nan_to_num(lax.real(x), nan=nan, posinf=posinf, neginf=neginf),
      nan_to_num(lax.imag(x), nan=nan, posinf=posinf, neginf=neginf))
  # Unspecified infinity replacements default to the dtype's extreme values.
  info = finfo(dtypes.canonicalize_dtype(dtype))
  posinf = info.max if posinf is None else posinf
  neginf = info.min if neginf is None else neginf
  x = where(isnan(x), _constant_like(x, nan), x)
  x = where(isposinf(x), _constant_like(x, posinf), x)
  x = where(isneginf(x), _constant_like(x, neginf), x)
  return x
1850
1851### Reducers
1852
def _reduction(a, name, np_fun, op, init_val, has_identity=True,
               preproc=None, bool_op=None, upcast_f16_for_computation=False,
               axis=None, dtype=None, out=None, keepdims=False, initial=None, where_=None):
  """Shared implementation behind the jnp reducers (sum, prod, max, ...).

  Args:
    a: array-like input to reduce.
    name: reducer name, used in error messages.
    np_fun: the NumPy counterpart; it is applied to a scalar one of ``a``'s
      dtype to infer the default result dtype when ``dtype`` is not given.
    op: binary operation handed to ``lax.reduce``.
    init_val: value used as the initial value of ``lax.reduce`` (converted to
      the computation dtype by ``_reduction_init_val``).
    has_identity: when False and ``initial`` is None, zero-size inputs and
      ``where_`` masks are rejected with ``ValueError``.
    preproc: optional callable applied to the array before reducing.
    bool_op: operation used instead of ``op`` when the computation dtype is
      bool; defaults to ``op``.
    upcast_f16_for_computation: if True and the result dtype is inexact, the
      reduction is computed in ``promote_types(result_dtype, float32)``.
    axis: None (all axes), an int, or a sequence of ints; must be concrete.
    dtype: optional user-requested result dtype.
    out: must be None; anything else raises ``NotImplementedError``.
    keepdims: if truthy, reduced axes are re-inserted via ``expand_dims``.
    initial: optional user-supplied starting value, combined with the result
      after the reduction.
    where_: optional mask; elements where it is false are replaced by the
      initial value before reducing.

  Returns:
    The reduced array, converted to ``dtype`` if given, else the inferred
    result dtype.
  """
  bool_op = bool_op or op
  if out is not None:
    raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.")
  _check_arraylike(name, a)
  lax._check_user_dtype_supported(dtype, name)
  axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().")

  # Reducers without an identity need an explicit `initial` for empty inputs
  # and for masked reductions.
  if initial is None and not has_identity:
    if not size(a):
      raise ValueError(f"zero-size array to reduction operation {name} which has no identity")
    if where_ is not None:
      raise ValueError(f"reduction operation {name} does not have an identity, so to use a "
                       f"where mask one has to specify 'initial'")

  a = a if isinstance(a, ndarray) else asarray(a)
  a = preproc(a) if preproc else a
  dims = _reduction_dims(a, axis)
  # Default result dtype: whatever NumPy's reducer returns for a scalar of
  # a's dtype, canonicalized.
  result_dtype = dtypes.canonicalize_dtype(dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a)))))
  if upcast_f16_for_computation and issubdtype(result_dtype, inexact):
    computation_dtype = promote_types(result_dtype, float32)
  else:
    computation_dtype = result_dtype
  a = lax.convert_element_type(a, computation_dtype)
  # Boolean computations use the dedicated boolean op.
  op = op if computation_dtype != np.bool_ else bool_op
  # NB: in XLA, init_val must be an identity for the op, so the user-specified
  # initial value must be applied afterward.
  init_val = _reduction_init_val(a, init_val)
  if where_ is not None:
    # Masked-out elements are replaced by the identity so they do not
    # contribute to the result.
    a = where(where_, a, init_val)
  result = lax.reduce(a, init_val, op, dims)
  if initial is not None:
    result = op(_reduction_init_val(a, initial), result)
  if keepdims:
    result = expand_dims(result, dims)
  return lax.convert_element_type(result, dtype or result_dtype)
1891
def _reduction_dims(a, axis):
  """Normalize ``axis`` into a tuple of canonical axis indices for ``a``.

  ``None`` selects every axis; a tuple, list or ndarray is canonicalized
  element-wise and must not contain duplicates; anything else is treated as
  a single axis.
  """
  rank = ndim(a)
  if axis is None:
    return tuple(range(rank))
  if not isinstance(axis, (np.ndarray, tuple, list)):
    return (_canonicalize_axis(axis, rank),)
  dims = tuple(_canonicalize_axis(ax, rank) for ax in axis)
  if len(set(dims)) != len(dims):
    raise ValueError(f"duplicate value in 'axis': {dims}")
  return dims
1902
def _reduction_init_val(a, init_val):
  """Return ``init_val`` as a NumPy scalar array of ``a``'s canonical dtype.

  For a boolean dtype the value is ``init_val > 0``. If the direct conversion
  overflows (only expected for integer dtypes), the value saturates to the
  dtype's minimum or maximum according to the sign of ``init_val``.
  """
  target = dtypes.canonicalize_dtype(_dtype(a))
  if target == 'bool':
    return np.array(init_val > 0, dtype=target)
  try:
    return np.array(init_val, dtype=target)
  except OverflowError:
    assert issubdtype(target, integer)
    limits = iinfo(target)
    saturated = limits.min if np.sign(init_val) < 0 else limits.max
    return np.array(saturated, dtype=target)

# Preprocessing step that converts an array to bool.
_cast_to_bool = partial(lax.convert_element_type, new_dtype=bool_)
1915
@_wraps(np.sum)
def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, keepdims=None, initial=None, where=None):
  # Additive reduction starting from 0. Bitwise-or stands in for addition when
  # the computation dtype is bool, and inexact results are computed in at
  # least float32.
  user_args = dict(axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                   initial=initial, where_=where)
  return _reduction(a, "sum", np.sum, lax.add, 0, bool_op=lax.bitwise_or,
                    upcast_f16_for_computation=True, **user_args)
1922
@_wraps(np.prod)
def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, keepdims=None, initial=None, where=None):
  # Multiplicative reduction starting from 1. Bitwise-and stands in for
  # multiplication when the computation dtype is bool, and inexact results
  # are computed in at least float32.
  user_args = dict(axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                   initial=initial, where_=where)
  return _reduction(a, "prod", np.prod, lax.mul, 1, bool_op=lax.bitwise_and,
                    upcast_f16_for_computation=True, **user_args)
1929
@_wraps(np.max)
def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, initial=None, where=None):
  # Maximum has no identity: empty inputs and `where` masks therefore need an
  # explicit `initial` (enforced inside _reduction).
  user_args = dict(axis=axis, out=out, keepdims=keepdims,
                   initial=initial, where_=where)
  return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False,
                    **user_args)
1935
@_wraps(np.min)
def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None, initial=None, where=None):
  # Minimum has no identity: empty inputs and `where` masks therefore need an
  # explicit `initial` (enforced inside _reduction).
  user_args = dict(axis=axis, out=out, keepdims=keepdims,
                   initial=initial, where_=where)
  return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False,
                    **user_args)
1941
@_wraps(np.all)
def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None):
  # The input is cast to bool first, then and-reduced starting from True.
  user_args = dict(axis=axis, out=out, keepdims=keepdims)
  return _reduction(a, "all", np.all, lax.bitwise_and, True,
                    preproc=_cast_to_bool, **user_args)
1947
@_wraps(np.any)
def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=None):
  # The input is cast to bool first, then or-reduced starting from False.
  user_args = dict(axis=axis, out=out, keepdims=keepdims)
  return _reduction(a, "any", np.any, lax.bitwise_or, False,
                    preproc=_cast_to_bool, **user_args)
1953
# Alternate public names bound to the reducers defined above.
product = prod
amin = min
amax = max
alltrue = all
sometrue = any
1959
@_wraps(np.mean)
def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
         out=None, keepdims=False):
  _check_arraylike("mean", a)
  lax._check_user_dtype_supported(dtype, "mean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.mean is not supported.")

  # Number of elements that contribute to each output value.
  count = size(a) if axis is None else np.prod(np.take(shape(a), axis))

  if dtype is None:
    # Boolean and integer inputs are averaged in the default float type.
    if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
      dtype = float_
    else:
      dtype = _dtype(a)
  dtype = dtypes.canonicalize_dtype(dtype)

  total = sum(a, axis, dtype=dtype, keepdims=keepdims)
  return lax.div(total, lax.convert_element_type(count, dtype))
1982
@_wraps(np.average)
def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None,
            returned=False):
  # Weighted average of `a` along `axis`. With `returned=True` the sum of the
  # weights is returned as well, broadcast to the shape of the average.
  a = asarray(a)

  if weights is None: # Treat all weights as 1
    avg = mean(a, axis=axis)
    if axis is None:
      weights_sum = full((), size(a), dtype=avg.dtype)
    else:
      # With unit weights the weight sum is the number of averaged elements.
      # np.take handles both a single axis and a tuple of axes (as accepted by
      # `mean` above), whereas indexing `a.shape[axis]` fails for tuples.
      n_averaged = np.prod(np.take(shape(a), axis))
      weights_sum = full_like(avg, n_averaged, dtype=avg.dtype)
  else:
    weights = asarray(weights)

    # Non-inexact inputs are additionally promoted with the default float.
    if issubdtype(a.dtype, inexact):
      out_dtype = result_type(a.dtype, weights.dtype)
    else:
      out_dtype = result_type(a.dtype, weights.dtype, float_)
    out_dtype = dtypes.canonicalize_dtype(out_dtype)

    a_shape = shape(a)
    a_ndim = len(a_shape)
    weights_shape = shape(weights)
    axis = None if axis is None else _canonicalize_axis(axis, a_ndim)

    if a_shape != weights_shape:
      # Make sure the dimensions work out
      if axis is None:
        raise ValueError("Axis must be specified when shapes of a and "
                         "weights differ.")
      if len(weights_shape) != 1:
        raise ValueError("1D weights expected when shapes of a and "
                         "weights differ.")
      if weights_shape[0] != a_shape[axis]:
        raise ValueError("Length of weights not "
                         "compatible with specified axis.")

      # Align the 1D weights with `axis` so they broadcast against `a`.
      weights = broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape)
      weights = moveaxis(weights, -1, axis)

    weights_sum = sum(weights, axis=axis, dtype=out_dtype)
    avg = sum(multiply(a, weights), axis=axis, dtype=out_dtype) / weights_sum

  if returned:
    if avg.shape != weights_sum.shape:
      weights_sum = broadcast_to(weights_sum, avg.shape)
    return avg, weights_sum
  return avg
2031
2032
@_wraps(np.var)
def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, ddof=0, keepdims=False):
  _check_arraylike("var", a)
  lax._check_user_dtype_supported(dtype, "var")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.var is not supported.")

  a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
  # Deviation from the mean; keepdims=True keeps the mean broadcastable.
  deviation = a - mean(a, axis, dtype=a_dtype, keepdims=True)
  if issubdtype(deviation.dtype, complexfloating):
    # |z|^2 for complex values, computed as real(z * conj(z)).
    squared = lax.real(lax.mul(deviation, lax.conj(deviation)))
  else:
    squared = lax.square(deviation)

  count = size(a) if axis is None else np.prod(np.take(shape(a), axis))
  divisor = count - ddof

  total = sum(squared, axis, keepdims=keepdims)
  variance = lax.div(total, lax.convert_element_type(divisor, total.dtype))
  return lax.convert_element_type(variance, dtype)
2058
2059
def _var_promote_types(a_dtype, dtype):
  """Pick the computation and result dtypes for a variance computation.

  Returns ``(a_dtype, dtype)``: the dtype in which the mean and deviations
  are computed, and the dtype of the final result. Requesting a non-complex
  ``dtype`` for complex input raises ``ValueError``.
  """
  if not dtype:
    if issubdtype(a_dtype, inexact):
      # Result uses the element type; computation is at least float32.
      return promote_types(a_dtype, float32), _complex_elem_type(a_dtype)
    # Non-inexact input: compute and return in the default float type.
    default_float = dtypes.canonicalize_dtype(float_)
    return default_float, default_float

  if (not issubdtype(dtype, complexfloating) and
      issubdtype(a_dtype, complexfloating)):
    msg = ("jax.numpy.var does not yet support real dtype parameters when "
           "computing the variance of an array of complex values. The "
           "semantics of numpy.var seem unclear in this case. Please comment "
           "on https://github.com/google/jax/issues/2283 if this behavior is "
           "important to you.")
    raise ValueError(msg)
  return promote_types(a_dtype, dtype), dtype
2078
2079
@_wraps(np.std)
def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
        out=None, ddof=0, keepdims=False):
  _check_arraylike("std", a)
  lax._check_user_dtype_supported(dtype, "std")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
  # Standard deviation is the square root of the variance.
  variance = var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)
  return sqrt(variance)
2088
2089
@_wraps(np.ptp)
def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
        keepdims=False):
  _check_arraylike("ptp", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.")
  # Peak-to-peak: maximum minus minimum over the requested axes.
  highest = amax(a, axis=axis, keepdims=keepdims)
  lowest = amin(a, axis=axis, keepdims=keepdims)
  return lax.sub(highest, lowest)
2099
2100
@_wraps(np.allclose)
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  # Reduce the elementwise closeness test to a single boolean.
  elementwise = isclose(a, b, rtol, atol, equal_nan)
  return all(elementwise)
2104
2105
@_wraps(np.count_nonzero)
def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  keepdims=False):
  _check_arraylike("count_nonzero", a)
  # Sum the boolean "differs from zero" mask using the canonical int dtype.
  is_nonzero = lax.ne(a, _constant_like(a, 0))
  count_dtype = dtypes.canonicalize_dtype(np.int_)
  return sum(is_nonzero, axis=axis, dtype=count_dtype, keepdims=keepdims)
2112
2113
_NONZERO_DOC = """\
At present, JAX does not support JIT-compilation of :py:func:`jax.numpy.nonzero`
because its output shape is data-dependent.
"""

@_wraps(np.nonzero, lax_description=_NONZERO_DOC)
def nonzero(a):
  # Note: this function cannot be jitted because its output has a dynamic
  # shape; the input therefore has to be concrete.
  a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero")
  dims = shape(a)
  rank = len(dims)
  # One index grid per axis, each with a trailing unit axis, stacked so that
  # the last axis holds the full coordinate of every element.
  grids = [lax.broadcasted_iota(int_, dims + (1,), ax) for ax in range(rank)]
  coords = concatenate(grids, axis=-1)
  selected = coords[a != 0]
  return tuple(selected[..., ax] for ax in range(rank))
2130
2131
@_wraps(np.flatnonzero)
def flatnonzero(a):
  # Nonzero positions of the flattened input: take the first (and, for a
  # raveled array, only) index array returned by `nonzero`.
  flat_indices = nonzero(ravel(a))
  return flat_indices[0]
2135
2136
def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan,
                   axis=None, keepdims=None, **kwargs):
  """Apply ``jnp_reduction`` to ``a`` with NaNs replaced by ``init_val``.

  When ``nan_if_all_nan`` is set, positions whose reduced elements are all
  NaN are returned as NaN rather than as the reduction of the replacement
  values. Extra keyword arguments are forwarded to ``jnp_reduction``.
  """
  _check_arraylike(name, a)
  neutral = _reduction_init_val(a, init_val)
  filled = where(isnan(a), neutral, a)
  out = jnp_reduction(filled, axis=axis, keepdims=keepdims, **kwargs)
  if not nan_if_all_nan:
    return out
  every_nan = all(isnan(a), axis=axis, keepdims=keepdims)
  return where(every_nan, _constant_like(a, nan), out)
2147
@_wraps(np.nanmin)
def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
           keepdims=None):
  # NaNs are replaced by +inf before taking the minimum; all-NaN slices
  # produce NaN.
  forwarded = dict(axis=axis, out=out, keepdims=keepdims)
  return _nan_reduction(a, 'nanmin', min, inf, nan_if_all_nan=True,
                        **forwarded)
2153
@_wraps(np.nanmax)
def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
           keepdims=None):
  # NaNs are replaced by -inf before taking the maximum; all-NaN slices
  # produce NaN.
  forwarded = dict(axis=axis, out=out, keepdims=keepdims)
  return _nan_reduction(a, 'nanmax', max, -inf, nan_if_all_nan=True,
                        **forwarded)
2159
@_wraps(np.nansum)
def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, keepdims=None):
  # NaNs are replaced by 0 so they do not contribute to the sum.
  forwarded = dict(axis=axis, dtype=dtype, out=out, keepdims=keepdims)
  return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False,
                        **forwarded)
2165
@_wraps(np.nanprod)
def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=None):
  # NaNs are replaced by 1 so they do not contribute to the product.
  forwarded = dict(axis=axis, dtype=dtype, out=out, keepdims=keepdims)
  return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False,
                        **forwarded)
2171
@_wraps(np.nanmean)
def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
            out=None, keepdims=False):
  _check_arraylike("nanmean", a)
  lax._check_user_dtype_supported(dtype, "nanmean")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.")
  # Boolean and integer inputs are handed to the plain mean.
  if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer):
    return mean(a, axis, dtype, out, keepdims)
  if dtype is None:
    dtype = _dtype(a)
  # Divide the NaN-ignoring sum by the number of non-NaN entries.
  is_valid = logical_not(isnan(a))
  valid_count = sum(is_valid, axis=axis, dtype=int32, keepdims=keepdims)
  valid_count = lax.convert_element_type(valid_count, dtype)
  total = nansum(a, axis, dtype=dtype, keepdims=keepdims)
  return lax.div(total, valid_count)
2188
2189
@_wraps(np.nanvar)
def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False):
  _check_arraylike("nanvar", a)
  lax._check_user_dtype_supported(dtype, "nanvar")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.")

  a_dtype, dtype = _var_promote_types(_dtype(a), dtype)
  # Deviation from the NaN-ignoring mean; keepdims=True keeps it broadcastable.
  deviation = a - nanmean(a, axis, dtype=a_dtype, keepdims=True)
  if issubdtype(deviation.dtype, complexfloating):
    squared = lax.real(lax.mul(deviation, lax.conj(deviation)))
  else:
    squared = lax.square(deviation)

  # Degrees of freedom: number of non-NaN entries minus ddof.
  dof = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims)
  dof = dof - ddof
  if config.omnistaging_enabled:
    no_dof = lax.le(dof, 0)
  else:
    zero = lax.full_like(dof, 0, shape=())
    no_dof = lax.le(dof, zero)

  total = nansum(squared, axis, keepdims=keepdims)
  # Where no degrees of freedom remain the result is NaN; the divisor is set
  # to 1 there so the division itself stays well defined.
  total = where(no_dof, nan, total)
  safe_dof = where(no_dof, 1, dof)
  variance = lax.div(total, lax.convert_element_type(safe_dof, total.dtype))
  return lax.convert_element_type(variance, dtype)
2219
2220
@_wraps(np.nanstd)
def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None,
           out=None, ddof=0, keepdims=False):
  _check_arraylike("nanstd", a)
  lax._check_user_dtype_supported(dtype, "nanstd")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.")
  # Square root of the NaN-ignoring variance.
  variance = nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims)
  return sqrt(variance)
2229
2230
def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0):
  """Build a NumPy-style cumulative reducer on top of ``reduction``.

  ``reduction`` is called as ``reduction(a, axis)``. When ``fill_nan`` is
  set, NaNs in the input are replaced by ``fill_value`` first. The returned
  function is wrapped with ``_wraps(np_reduction)``.
  """
  # We want to allow XLA to fuse the pad and reduce-window operators to
  # avoid materializing the padded output.
  # Consider removing `jit` once again if reduce-window is generalized to
  # support arbitrary padding.
  @partial(jit, static_argnums=(1, 2))
  def _cumulative_reduction(a, axis, dtype):
    # Without an axis (or for scalar input) the reduction runs over the
    # flattened array.
    if axis is None or isscalar(a):
      a, axis = ravel(a), 0

    axis = _canonicalize_axis(axis, len(shape(a)))

    if fill_nan:
      a = where(isnan(a), _constant_like(a, fill_value), a)

    # Boolean input without an explicit dtype is accumulated as int_.
    if not dtype and _dtype(a) == bool_:
      dtype = int_
    if dtype:
      a = lax.convert_element_type(a, dtype)

    return reduction(a, axis)

  @_wraps(np_reduction)
  def cumulative_reduction(a,
                           axis: Optional[Union[int, Tuple[int, ...]]] = None,
                           dtype=None, out=None):
    op_name = np_reduction.__name__
    _check_arraylike(op_name, a)
    if out is not None:
      raise NotImplementedError(f"The 'out' argument to jnp.{op_name} "
                                f"is not supported.")
    lax._check_user_dtype_supported(dtype, op_name)
    # jit doesn't support kwargs as static_args.
    return _cumulative_reduction(a, axis, dtype)
  return cumulative_reduction


cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False)
cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False)
cumproduct = cumprod
nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum,
                                       fill_nan=True, fill_value=0)
nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod,
                                        fill_nan=True, fill_value=1)
2277
2278
@_wraps(np.unwrap)
def unwrap(p, discont=pi, axis: int = -1):
  _check_arraylike("unwrap", p)
  deltas = diff(p, axis=axis)
  # Wrap the differences into [-pi, pi); a wrapped value of -pi that came
  # from a positive step is mapped to +pi instead.
  wrapped = mod(deltas + pi, 2 * pi) - pi
  wrapped = where((wrapped == -pi) & (deltas > 0), pi, wrapped)

  # Steps smaller than `discont` are left untouched.
  correction = where(abs(deltas) < discont, 0, wrapped - deltas)

  head = lax.slice_in_dim(p, 0, 1, axis=axis)
  tail = lax.slice_in_dim(p, 1, None, axis=axis) + cumsum(correction, axis=axis)
  return concatenate((head, tail), axis=axis)
2294
2295
2296### Array-creation functions
2297
def _check_no_padding(axis_padding, mode):
  """Raise ``ValueError`` if either entry of ``axis_padding`` is positive.

  ``mode`` only appears in the error message. Callers use this for axes of
  length zero.
  """
  wants_padding = axis_padding[0] > 0 or axis_padding[1] > 0
  if not wants_padding:
    return
  raise ValueError("Cannot apply '{}' padding to empty axis".format(mode))
2302
2303
def _pad_constant(array, pad_width, constant_values):
  """Pad every axis of ``array`` with constant values.

  ``constant_values`` is broadcast to shape ``(ndim, 2)`` (before/after per
  axis) and converted to ``array.dtype``. ``pad_width`` is indexed as
  ``pad_width[axis, side]``.
  """
  nd = ndim(array)
  fill = broadcast_to(asarray(constant_values), (nd, 2))
  fill = lax.convert_element_type(fill, array.dtype)
  for axis in range(nd):
    # The low side and the high side may use different fill values, so each
    # is padded with its own lax.pad call.
    low_config = [(0, 0, 0)] * nd
    low_config[axis] = (pad_width[axis, 0], 0, 0)
    array = lax.pad(array, fill[axis, 0], low_config)

    high_config = [(0, 0, 0)] * nd
    high_config[axis] = (0, pad_width[axis, 1], 0)
    array = lax.pad(array, fill[axis, 1], high_config)
  return array
2315
2316
def _pad_wrap(array, pad_width):
  """Pad each axis by wrapping the array around periodically."""
  for axis in range(ndim(array)):
    length = array.shape[axis]
    if length == 0:
      _check_no_padding(pad_width[axis], "wrap")
      continue
    # Split each side's padding into whole copies of the axis plus a
    # partial remainder.
    whole, (lead_rem, trail_rem) = _divmod(pad_width[axis], length)
    n_copies = whole.sum() + 1  # +1 accounts for the array itself
    leading = []
    if lead_rem:
      leading = [lax.slice_in_dim(array, length - lead_rem, length, axis=axis)]
    trailing = []
    if trail_rem:
      trailing = [lax.slice_in_dim(array, 0, trail_rem, axis=axis)]
    pieces = leading + n_copies * [array] + trailing
    array = lax.concatenate(pieces, dimension=axis)
  return array
2333
2334
def _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type):
  """Pad every axis of ``array`` by mirroring its contents.

  Args:
    array: array to pad.
    pad_width: per-axis ``(before, after)`` widths, indexed as
      ``pad_width[i, side]``.
    mode: ``"symmetric"`` or ``"reflect"``.
    reflect_type: ``"even"`` or ``"odd"``; for ``"odd"`` the mirrored values
      are replaced by ``2 * edge - mirrored``.

  Returns:
    The padded array.
  """
  assert mode in ("symmetric", "reflect")
  assert reflect_type in ("even", "odd")

  for i in range(ndim(array)):
    if array.shape[i] == 0:
      # Empty axes have nothing to mirror; only zero padding is allowed.
      _check_no_padding(pad_width[i], mode)
      continue

    n = array.shape[i]
    # In "reflect" mode the mirrored slice skips the edge element itself
    # (unless the axis has a single element).
    offset = 1 if (mode == "reflect" and n > 1) else 0

    def build_padding(array, padding, before):
      # Adds `padding` elements on one side of axis `i`, in chunks of at most
      # `n - offset` elements per iteration.
      if before:
        edge = lax.slice_in_dim(array, 0, 1, axis=i)
      else:
        edge = lax.slice_in_dim(array, -1, None, axis=i)

      while padding > 0:
        curr_pad = _min(padding, n - offset)
        padding -= curr_pad

        if before:
          start = offset
          stop = offset + curr_pad
        else:
          start = -(curr_pad + offset)
          stop = None if (mode == "symmetric" or n == 1) else -1

        # Take the slice next to the edge and mirror it.
        x = lax.slice_in_dim(array, start, stop, axis=i)
        x = flip(x, axis=i)

        if reflect_type == 'odd':
          x = 2 * edge - x
          if n > 1:
            # The outermost element of the new chunk becomes the edge for
            # the next chunk.
            if before:
              edge = lax.slice_in_dim(x, 0, 1, axis=i)
            else:
              edge = lax.slice_in_dim(x, -1, None, axis=i)

        if before:
          array = lax.concatenate([x, array], dimension=i)
        else:
          array = lax.concatenate([array, x], dimension=i)
      return array

    array = build_padding(array, pad_width[i, 0], before=True)
    array = build_padding(array, pad_width[i, 1], before=False)
  return array
2384
2385
def _pad_edge(array, pad_width):
  """Pad each axis by repeating its first and last elements."""
  for axis in range(ndim(array)):
    length = array.shape[axis]
    if length == 0:
      _check_no_padding(pad_width[axis], "edge")
      continue

    n_before, n_after = pad_width[axis]

    first = lax.slice_in_dim(array, 0, 1, axis=axis)
    leading = repeat(first, n_before, axis=axis)

    last = lax.slice_in_dim(array, length - 1, length, axis=axis)
    trailing = repeat(last, n_after, axis=axis)

    array = lax.concatenate([leading, array, trailing], dimension=axis)
  return array
2404
2405
def _pad_linear_ramp(array, pad_width, end_values):
  """Pad each axis with a linear ramp between ``end_values`` and the edges."""
  for axis in range(ndim(array)):

    def _ramp(side, edge):
      # Ramp from the requested end value toward (but excluding) the edge;
      # linspace re-creates the squeezed dimension at position `axis`.
      return linspace(
          start=end_values[axis][side],
          stop=edge.squeeze(axis),
          num=pad_width[axis][side],
          endpoint=False,
          dtype=array.dtype,
          axis=axis
      )

    first = lax.slice_in_dim(array, 0, 1, axis=axis)
    last = lax.slice_in_dim(array, -1, None, axis=axis)
    leading = _ramp(0, first)
    trailing = _ramp(1, last)

    # The trailing ramp has to run from the edge outward, so reverse it.
    trailing = flip(trailing, axis)

    array = lax.concatenate([leading, array, trailing], dimension=axis)
  return array
2432
2433
def _pad_stats(array, pad_width, stat_length, stat_func):
  """Pad each axis with a statistic (``stat_func``) of the array's values.

  With ``stat_length=None`` the statistic covers the whole axis; otherwise it
  is computed over the first/last ``stat_length[axis]`` elements, clipped to
  the axis length.
  """
  for axis in range(ndim(array)):
    if stat_length is None:
      lead_stat = stat_func(array, axis=axis, keepdims=True)
      trail_stat = lead_stat
    else:
      axis_len = array.shape[axis]
      n_lead, n_trail = stat_length[axis]
      if 0 in (n_lead, n_trail):
        raise ValueError("stat_length of 0 yields no value for padding")

      # Limit stat_length to length of array.
      n_lead = _min(n_lead, axis_len)
      n_trail = _min(n_trail, axis_len)

      lead_window = lax.slice_in_dim(array, 0, n_lead, axis=axis)
      trail_window = lax.slice_in_dim(array, -n_trail, None, axis=axis)
      lead_stat = stat_func(lead_window, axis=axis, keepdims=True)
      trail_stat = stat_func(trail_window, axis=axis, keepdims=True)

    # Integer arrays get a rounded statistic before the cast back.
    if np.issubdtype(array.dtype, np.integer):
      lead_stat = round(lead_stat)
      trail_stat = round(trail_stat)

    lead_stat = lead_stat.astype(array.dtype)
    trail_stat = trail_stat.astype(array.dtype)

    n_before, n_after = pad_width[axis]
    leading = repeat(lead_stat, n_before, axis=axis)
    trailing = repeat(trail_stat, n_after, axis=axis)

    array = lax.concatenate([leading, array, trailing], dimension=axis)
  return array
2468
2469
def _pad_empty(array, pad_width):
  """Pad each axis with blocks produced by ``empty``."""
  # Note: jax.numpy.empty = jax.numpy.zeros
  for axis in range(ndim(array)):

    def _block(width):
      # Same shape as the current array except for `width` along `axis`.
      dims = array.shape[:axis] + (width,) + array.shape[axis + 1:]
      return empty(dims, dtype=array.dtype)

    leading = _block(pad_width[axis][0])
    trailing = _block(pad_width[axis][1])
    array = lax.concatenate([leading, array, trailing], dimension=axis)
  return array
2480
2481
def _pad_func(array, pad_width, func, **kwargs):
  """Pad using a user-supplied ``func``.

  The array is first zero-padded to its final shape; ``func`` is then applied
  along every axis via ``apply_along_axis`` and receives that axis' pad
  widths, the axis index and the ``kwargs`` dict as extra arguments.
  """
  pairs = _broadcast_to_pairs(pad_width, ndim(array), "pad_width")
  result = _pad_constant(array, np.array(pairs), 0)
  for axis in range(ndim(result)):
    result = apply_along_axis(func, axis, result, pairs[axis], axis, kwargs)
  return result
2488
2489
def _broadcast_to_pairs(nvals, nd, name):
  """Expand ``nvals`` into ``nd`` ``(before, after)`` pairs.

  Accepted structures, identified by ``np.shape(nvals)``:
    * ``(nd, 2)``: ``((before_1, after_1), ..., (before_N, after_N))``,
      returned unchanged.
    * ``(1, 2)``: ``((before, after),)``, repeated for every axis.
    * ``(2,)``: ``(before, after)``, repeated for every axis (not in the
      numpy docstring but works anyway).
    * ``(1,)``: ``(pad,)``, used for both sides of every axis.
    * ``()``: ``pad``, used for both sides of every axis.

  Args:
    nvals: scalar or nested sequence in one of the structures above.
    nd: number of axes to produce pairs for.
    name: argument name, used only in the error message.

  Returns:
    ``nvals`` itself for the ``(nd, 2)`` case, otherwise a tuple of ``nd``
    pairs.

  Raises:
    ValueError: if ``nvals`` has none of the accepted structures.
  """
  nvals_shape = np.shape(nvals)
  if nvals_shape == (nd, 2):
    return nvals
  if nvals_shape == (1, 2):
    # tuple(...) makes `* nd` a sequence repeat even for ndarray input, where
    # a bare `nvals * nd` would multiply the values instead.
    return tuple(nvals) * nd
  if nvals_shape == (2,):
    return (nvals,) * nd
  if nvals_shape == (1,):
    pad, = nvals
    return ((pad, pad),) * nd
  if nvals_shape == ():
    return ((nvals, nvals),) * nd
  raise ValueError(f"{name} given unexpected structure: {nvals}. "
                   f"See docstring for valid {name} formats.")
2513
2514
@partial(jit, static_argnums=(1, 2, 4, 5, 6))
def _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type):
  """Jitted implementation of ``pad`` for the string-valued modes.

  All arguments except ``array`` and ``constant_values`` are static under
  ``jit``. ``pad_width`` is normalized to a NumPy array of shape
  ``(ndim, 2)`` and must not contain negative entries; the work is then
  delegated to the mode-specific ``_pad_*`` helper. Zero-dimensional input
  is returned unchanged.

  Raises:
    ValueError: if any pad width is negative.
    AssertionError: if ``mode`` is not one of the handled modes (``pad`` is
      expected to have rejected those already).
  """
  array = asarray(array)
  nd = ndim(array)

  if nd == 0:
    return array

  stat_funcs = {"maximum": amax, "minimum": amin,
                "mean": mean, "median": median}

  pad_width = _broadcast_to_pairs(pad_width, nd, "pad_width")
  pad_width = np.array(pad_width)
  assert pad_width.shape == (nd, 2), pad_width

  if np.any(pad_width < 0):
    raise ValueError("index can't contain negative values")

  if mode == "constant":
    return _pad_constant(array, pad_width, constant_values)

  elif mode == "wrap":
    return _pad_wrap(array, pad_width)

  elif mode in ("symmetric", "reflect"):
    return _pad_symmetric_or_reflect(array, pad_width, mode, reflect_type)

  elif mode == "edge":
    return _pad_edge(array, pad_width)

  elif mode == "linear_ramp":
    end_values = _broadcast_to_pairs(end_values, nd, "end_values")
    return _pad_linear_ramp(array, pad_width, end_values)

  elif mode in stat_funcs:
    if stat_length is not None:
      stat_length = _broadcast_to_pairs(stat_length, nd, "stat_length")
    return _pad_stats(array, pad_width, stat_length, stat_funcs[mode])

  elif mode == "empty":
    return _pad_empty(array, pad_width)

  else:
    # Raised explicitly (rather than `assert False`) so the guard survives
    # running Python with -O.
    raise AssertionError(
        "Should not be reached since pad already handled unsupported and "
        "not implemented modes")
2560
2561
@_wraps(np.pad, lax_description="""\
Unlike numpy, JAX "function" mode's argument (which is another function) should return
the modified array. This is because Jax arrays are immutable.
(In numpy, "function" mode's argument should modify a rank 1 array in-place.)
""")
def pad(array, pad_width, mode="constant", **kwargs):
  # Turn an iterable pad_width into (nested) tuples of Python ints.
  if isinstance(pad_width, Iterable):
    def _freeze(entry):
      if isinstance(entry, Iterable):
        return tuple(int(i) for i in entry)
      return entry
    pad_width = tuple(_freeze(entry) for entry in pad_width)

  # A callable mode is handled by the function-based padding path.
  if callable(mode):
    return _pad_func(array, pad_width, mode, **kwargs)

  # Keyword arguments accepted by each string-valued mode.
  allowed_kwargs = {
      'empty': (),
      'edge': (),
      'wrap': (),
      'constant': ('constant_values',),
      'linear_ramp': ('end_values',),
      'maximum': ('stat_length',),
      'mean': ('stat_length',),
      'median': ('stat_length',),
      'minimum': ('stat_length',),
      'reflect': ('reflect_type',),
      'symmetric': ('reflect_type',),
  }
  try:
    unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode])
  except KeyError:
    msg = "Unimplemented padding mode '{}' for np.pad."
    raise NotImplementedError(msg.format(mode))
  if unsupported_kwargs:
    raise ValueError("unsupported keyword arguments for mode '{}': {}"
                     .format(mode, unsupported_kwargs))

  # Set default value if not given.
  constant_values = kwargs.get('constant_values', 0)
  stat_length = kwargs.get('stat_length')
  end_values = kwargs.get('end_values', 0)
  reflect_type = kwargs.get('reflect_type', "even")

  return _pad(array, pad_width, mode, constant_values, stat_length, end_values, reflect_type)
2602
2603
@_wraps(np.stack)
def stack(arrays, axis: int = 0, out=None):
  if len(arrays) == 0:
    raise ValueError("Need at least one array to stack.")
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.stack is not supported.")
  _check_arraylike("stack", *arrays)

  # The result has one more dimension than the inputs, so the axis is
  # canonicalized against rank + 1.
  reference_shape = shape(arrays[0])
  axis = _canonicalize_axis(axis, len(reference_shape) + 1)

  # Give each input a new unit axis, then join along that axis.
  expanded = []
  for item in arrays:
    if shape(item) != reference_shape:
      raise ValueError("All input arrays must have the same shape.")
    expanded.append(expand_dims(item, axis))
  return concatenate(expanded, axis=axis)
2619
@_wraps(np.tile)
def tile(A, reps):
  _check_arraylike("tile", A)
  # Accept a bare integer as a one-element repetition tuple.
  try:
    iter(reps)
  except TypeError:
    reps = (reps,)
  reps = tuple(operator.index(rep) for rep in reps)

  # Left-pad the shorter of (shape, reps) with ones so both have equal length.
  padded_shape = (1,) * (len(reps) - ndim(A)) + shape(A)
  reps = (1,) * (len(padded_shape) - len(reps)) + reps

  # Interleave a unit axis before every real axis, broadcast each unit axis to
  # its repetition count, then merge each (rep, dim) pair back into one axis.
  source_shape = []
  for dim in padded_shape:
    source_shape.extend((1, dim))
  target_shape = []
  for rep, dim in zip(reps, padded_shape):
    target_shape.extend((rep, dim))
  tiled = broadcast_to(reshape(A, source_shape), target_shape)
  return reshape(tiled, tuple(np.multiply(padded_shape, reps)))
2633
@_wraps(np.concatenate)
def concatenate(arrays, axis: int = 0):
  _check_arraylike("concatenate", *arrays)
  if len(arrays) == 0:
    raise ValueError("Need at least one array to concatenate.")
  if ndim(arrays[0]) == 0:
    raise ValueError("Zero-dimensional arrays cannot be concatenated.")
  if axis is None:
    # axis=None means: flatten every input and join the 1-D results.
    flattened = [ravel(a) for a in arrays]
    return concatenate(flattened, axis=0)

  axis = _canonicalize_axis(axis, ndim(arrays[0]))
  arrays = _promote_dtypes(*arrays)
  if len(arrays) == 1:
    return asarray(arrays[0])

  # lax.concatenate can be slow to compile for wide concatenations, so form a
  # tree of concatenations as a workaround especially for op-by-op mode.
  # (https://github.com/google/jax/issues/653).
  group_size = 16
  while len(arrays) > 1:
    merged = []
    for begin in range(0, len(arrays), group_size):
      merged.append(lax.concatenate(arrays[begin:begin + group_size], axis))
    arrays = merged
  return arrays[0]
2656
2657
@_wraps(np.vstack)
def vstack(tup):
  # Promote every input to at least 2-D, then join along the first axis.
  rows = [atleast_2d(item) for item in tup]
  return concatenate(rows, axis=0)
row_stack = vstack
2662
2663
@_wraps(np.hstack)
def hstack(tup):
  # Inputs are promoted to at least 1-D. 1-D inputs are joined along axis 0,
  # everything else along axis 1.
  arrs = [atleast_1d(m) for m in tup]
  # The `arrs and` guard keeps an empty input from failing with IndexError on
  # `arrs[0]`; it falls through to `concatenate`, which raises the ValueError
  # for empty input, consistent with `vstack`.
  if arrs and arrs[0].ndim == 1:
    return concatenate(arrs, 0)
  return concatenate(arrs, 1)
2670
2671
@_wraps(np.dstack)
def dstack(tup):
  # Promote every input to at least 3-D, then join along the third axis.
  slabs = [atleast_3d(item) for item in tup]
  return concatenate(slabs, axis=2)
2675
2676
@_wraps(np.column_stack)
def column_stack(tup):
  def _as_column(value):
    # Inputs below rank 2 are made 2-D and transposed; others pass through.
    arr = asarray(value)
    return atleast_2d(arr).T if arr.ndim < 2 else arr

  return concatenate([_as_column(v) for v in tup], 1)
2686
2687
@_wraps(np.choose)
def choose(a, choices, out=None, mode='raise'):
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.choose is not supported.")
  _check_arraylike('choose', a, *choices)
  index_is_integer = issubdtype(_dtype(a), integer)
  if not index_is_integer:
    raise ValueError("`a` array must be integer typed")
  N = len(choices)

  # Bring the index array into [0, N) according to `mode`.
  if mode == 'raise':
    # Range checking needs concrete values, so this mode cannot be traced.
    a = core.concrete_or_error(asarray, a,
      "The error occurred because jnp.choose was jit-compiled"
      " with mode='raise'. Use mode='wrap' or mode='clip' instead.")
    out_of_range = (a < 0) | (a >= N)
    if any(out_of_range):
      raise ValueError("invalid entry in choice array")
  elif mode == 'wrap':
    a = a % N
  elif mode == 'clip':
    a = clip(a, 0, N - 1)
  else:
    raise ValueError(f"mode={mode!r} not understood. Must be 'raise', 'wrap', or 'clip'")

  # Stack the broadcast choices on a new leading axis and index it with
  # (a, i0, i1, ...), where the i's are sparse index grids over a's shape.
  a, *choices = broadcast_arrays(a, *choices)
  selector = (a,) + indices(a.shape, sparse=True)
  return array(choices)[selector]
2712
2713
def _atleast_nd(x, n):
  """Return `x` broadcast with extra size-1 dimensions so its rank is >= `n`."""
  missing = n - ndim(x)
  if missing > 0:
    return lax.broadcast(x, (1,) * missing)
  return x
2717
def _block(xs):
  """Recursive helper for `block`.

  Returns a pair `(array, depth)`: a leaf becomes `(asarray(leaf), 1)`, and a
  list becomes the concatenation of its processed children along axis
  `-child_depth`, with depth `child_depth + 1`.
  """
  if isinstance(xs, tuple):
    raise ValueError("jax.numpy.block does not allow tuples, got {}"
                     .format(xs))
  if not isinstance(xs, list):
    return asarray(xs), 1
  if len(xs) == 0:
    raise ValueError("jax.numpy.block does not allow empty list arguments")

  xs, depths = unzip2([_block(x) for x in xs])
  child_depth = depths[0]
  if _any(d != child_depth for d in depths[1:]):
    raise ValueError("Mismatched list depths in jax.numpy.block")

  # Every child needs enough dimensions for the concatenation axis to exist.
  rank = _max(child_depth, _max(ndim(x) for x in xs))
  xs = [_atleast_nd(x, rank) for x in xs]
  return concatenate(xs, axis=-child_depth), child_depth + 1
2733
@_wraps(np.block)
@jit
def block(arrays):
  # `_block` returns (assembled_array, nesting_depth); only the array is public.
  assembled, _depth = _block(arrays)
  return assembled
2739
2740
@_wraps(np.atleast_1d, update_doc=False)
def atleast_1d(*arys):
  # Any number of arguments other than one yields a list of converted results.
  if len(arys) != 1:
    return [atleast_1d(arr) for arr in arys]

  arr = asarray(arys[0])
  if ndim(arr) >= 1:
    return arr
  # Scalars become length-1 vectors.
  return reshape(arr, -1)
2748
2749
@_wraps(np.atleast_2d, update_doc=False)
def atleast_2d(*arys):
  # Any number of arguments other than one yields a list of converted results.
  if len(arys) != 1:
    return [atleast_2d(arr) for arr in arys]

  arr = asarray(arys[0])
  rank = ndim(arr)
  if rank >= 2:
    return arr
  # A vector gains a leading axis; a scalar gains two axes.
  new_axes = 0 if rank == 1 else (0, 1)
  return expand_dims(arr, axis=new_axes)
2762
2763
@_wraps(np.atleast_3d, update_doc=False)
def atleast_3d(*arys):
  # Any number of arguments other than one yields a list of converted results.
  if len(arys) != 1:
    return [atleast_3d(arr) for arr in arys]

  arr = asarray(arys[0])
  # Axes to insert for each input rank below 3; higher ranks pass through.
  axes_for_rank = {0: (0, 1, 2), 1: (0, 2), 2: 2}
  rank = ndim(arr)
  if rank in axes_for_rank:
    arr = expand_dims(arr, axis=axes_for_rank[rank])
  return arr
2777
2778
@_wraps(np.array)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
  # Only the default ordering is implemented.
  if order is not None and order != "K":
    raise NotImplementedError("Only implemented for order='K'")
  lax._check_user_dtype_supported(dtype, "array")
  dtype = dtype and dtypes.canonicalize_dtype(dtype)

  # Inputs whose leaves contain no Tracer/DeviceArray are converted by NumPy
  # in one step, which already applies `dtype` and `ndmin`.
  if _can_call_numpy_array(object):
    object = _np_array(object, dtype=dtype, ndmin=ndmin, copy=False)
  assert type(object) not in dtypes.python_scalar_dtypes

  if type(object) is np.ndarray:
    out = _device_put_raw(object)
    if dtype: assert _dtype(out) == dtype
  elif isinstance(object, (DeviceArray, core.Tracer)):
    if isinstance(object, DeviceArray) and copy:
      # We perform a copy by bouncing back to the host
      # TODO(phawkins): add a device runtime function to copy a buffer
      out = _device_put_raw(_np_asarray(object))
    else:
      out = object
  elif isinstance(object, (list, tuple)):
    # Sequences that reach this branch are converted element-wise and stacked.
    if object:
      out = stack([asarray(elt, dtype=dtype) for elt in object])
    else:
      out = _device_put_raw(_np_array([], dtype=dtype))
  else:
    try:
      view = memoryview(object)
    except TypeError:
      pass  # `object` does not support the buffer interface.
    else:
      # Forward `ndmin` too; the recursive call previously dropped it, so
      # buffer-protocol inputs ignored the requested minimum rank.
      return array(_np_asarray(view), dtype, copy, ndmin=ndmin)

    raise TypeError("Unexpected input type for array: {}".format(type(object)))

  if dtype and _dtype(out) != dtype:
    out = lax.convert_element_type(out, dtype)

  # Add extra dimensions of size 1 until the result has `ndmin` dimensions.
  if ndmin > ndim(out):
    out = lax.broadcast(out, (1,) * (ndmin - ndim(out)))
  return out
2821
def _can_call_numpy_array(x):
  """Return True when no pytree leaf of `x` is a Tracer or a DeviceArray."""
  blocking_types = (core.Tracer, DeviceArray)
  return not _any(isinstance(leaf, blocking_types) for leaf in tree_leaves(x))
2825
2826
@_wraps(np.asarray)
def asarray(a, dtype=None, order=None):
  lax._check_user_dtype_supported(dtype, "asarray")
  if dtype is not None:
    dtype = dtypes.canonicalize_dtype(dtype)
  # Same as `array`, but with copy=False.
  return array(a, dtype=dtype, copy=False, order=order)
2832
2833
@_wraps(np.zeros_like)
def zeros_like(a, dtype=None, shape=None):
  _check_arraylike("zeros_like", a)
  lax._check_user_dtype_supported(dtype, "zeros_like")
  # A scalar shape override is treated as a one-element shape.
  shape = (shape,) if np.isscalar(shape) else shape
  return lax.full_like(a, 0, dtype, shape)
2841
2842
@_wraps(np.ones_like)
def ones_like(a, dtype=None, shape=None):
  _check_arraylike("ones_like", a)
  lax._check_user_dtype_supported(dtype, "ones_like")
  # A scalar shape override is treated as a one-element shape.
  shape = (shape,) if np.isscalar(shape) else shape
  return lax.full_like(a, 1, dtype, shape)
2850
2851
@_wraps(np.full)
def full(shape, fill_value, dtype=None):
  lax._check_user_dtype_supported(dtype, "full")
  # A zero-dimensional `shape` (a single number) means a 1-D result.
  if ndim(shape) == 0:
    shape = (shape,)
  return lax.full(shape, fill_value, dtype)
2857
2858
@_wraps(np.full_like)
def full_like(a, fill_value, dtype=None, shape=None):
  _check_arraylike("full_like", a)
  lax._check_user_dtype_supported(dtype, "full_like")
  # A scalar shape override is treated as a one-element shape.
  shape = (shape,) if np.isscalar(shape) else shape
  return lax.full_like(a, fill_value, dtype, shape)
2866
2867
@_wraps(np.zeros)
def zeros(shape, dtype=None):
  # Generators are rejected explicitly.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  lax._check_user_dtype_supported(dtype, "zeros")
  if dtype is None:
    dtype = float_
  # A zero-dimensional `shape` (a single number) means a 1-D result.
  if ndim(shape) == 0:
    shape = (shape,)
  return lax.full(shape, 0, dtype)
2876
@_wraps(np.ones)
def ones(shape, dtype=None):
  # Generators are rejected explicitly.
  if isinstance(shape, types.GeneratorType):
    raise TypeError("expected sequence object with len >= 0 or a single integer")
  lax._check_user_dtype_supported(dtype, "ones")
  if dtype is None:
    dtype = float_
  # A zero-dimensional `shape` (a single number) means a 1-D result.
  if ndim(shape) == 0:
    shape = (shape,)
  return lax.full(shape, 1, dtype)
2885
2886
@_wraps(np.array_equal)
def array_equal(a1, a2, equal_nan=False):
  # Inputs that cannot be converted to arrays compare as unequal.
  try:
    a1, a2 = asarray(a1), asarray(a2)
  except Exception:
    return False
  # Shapes must match exactly; no broadcasting here.
  if shape(a1) != shape(a2):
    return False

  matches = asarray(a1 == a2)
  if equal_nan:
    # Positions where both inputs are NaN also count as equal.
    both_nan = logical_and(isnan(a1), isnan(a2))
    matches = logical_or(matches, both_nan)
  return all(matches)
2899
2900
@_wraps(np.array_equiv)
def array_equiv(a1, a2):
  # Inputs that cannot be converted to arrays compare as not equivalent.
  try:
    lhs, rhs = asarray(a1), asarray(a2)
  except Exception:
    return False

  try:
    matches = equal(lhs, rhs)
  except ValueError:
    # shapes are not broadcastable
    return False
  return all(matches)
2913
2914
# We can't create uninitialized arrays in XLA; use zeros for empty.
# `empty_like` and `empty` are therefore plain aliases of `zeros_like` and
# `zeros`, so their results are zero-filled.
empty_like = zeros_like
empty = zeros
2918
2919
@_wraps(np.eye)
def eye(N, M=None, k=0, dtype=None):
  lax._check_user_dtype_supported(dtype, "eye")
  if dtype is None:
    dtype = float_
  # operator.index accepts only true integers (or objects with __index__).
  N = operator.index(N)
  if M is None:
    M = N
  else:
    M = operator.index(M)
  if N < 0 or M < 0:
    raise ValueError(f"negative dimensions are not allowed, got {N} and {M}")
  k = operator.index(k)
  return lax._eye(dtype, (N, M), k)
2930
2931
@_wraps(np.identity)
def identity(n, dtype=None):
  lax._check_user_dtype_supported(dtype, "identity")
  # A square `eye` with the main diagonal set (eye's defaults spelled out).
  return eye(n, M=None, k=0, dtype=dtype)
2936
2937
@_wraps(np.arange)
def arange(start, stop=None, step=None, dtype=None):
  lax._check_user_dtype_supported(dtype, "arange")

  def _concrete(value, arg_name):
    # Every bound must be a concrete value; abstract tracers are rejected
    # with a message naming the offending argument.
    return core.concrete_or_error(
        _np_asarray, value,
        "It arose in jax.numpy.arange argument `{}`.".format(arg_name))

  if stop is None and step is None:
    # Single-argument form: the lone positional value acts as `stop`.
    start = _concrete(start, "stop")
    dtype = dtype or _dtype(start)
    return lax.iota(dtype, np.ceil(start)) # avoids materializing

  start = _concrete(start, "start")
  if stop is not None:
    stop = _concrete(stop, "stop")
  if step is not None:
    step = _concrete(step, "step")
  if dtype is None:
    provided = [bound for bound in [stop, step] if bound is not None]
    dtype = _dtype(start, *provided)
  return array(np.arange(start, stop=stop, step=step, dtype=dtype))
2954
2955
def _wrap_numpy_nullary_function(f):
  """Wrap `f` so that its result is passed through `asarray`.

  `f` is called with the wrapper's arguments unchanged, so it cannot have any
  non-static array arguments.
  """
  @_wraps(f, update_doc=False)
  def wrapper(*args, **kwargs):
    host_result = f(*args, **kwargs)
    return asarray(host_result)
  return wrapper
2965
2966
@_wraps(np.linspace)
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None,
             axis: int = 0):
  """Implementation of linspace differentiable in start and stop args."""
  lax._check_user_dtype_supported(dtype, "linspace")
  if num < 0:
    raise ValueError("Number of samples, %s, must be non-negative." % num)

  # Arithmetic is carried out in `computation_dtype` (the output dtype promoted
  # with the canonical float type); the cast to `dtype` happens on return.
  dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
  computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)

  # `bounds_shape` is the broadcast shape of the endpoints with a unit axis
  # inserted at `axis`; `iota_shape` is all ones except `num` at `axis`, so the
  # sample index broadcasts against the endpoints.
  bounds_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
  broadcast_start = broadcast_to(start, bounds_shape)
  broadcast_stop = broadcast_to(stop, bounds_shape)
  # A negative axis counts from the end of the output, which has one more
  # dimension than the endpoints (hence the + 1).
  axis = len(bounds_shape) + axis + 1 if axis < 0 else axis
  bounds_shape.insert(axis, 1)
  iota_shape = [1,] * len(bounds_shape)
  iota_shape[axis] = num
  # Number of intervals between samples.
  div = (num - 1) if endpoint else num
  if num > 1:
    delta = lax.convert_element_type(stop - start, computation_dtype) / div
    if issubdtype(dtype, integer):
      # This is similar to how numpy computes linspace, but it
      # can fail to recover the endpoints in float32 arithmetic.
      out = (reshape(broadcast_start, bounds_shape) +
        reshape(lax.iota(dtype, num), iota_shape) *
        reshape(delta, bounds_shape))
    else:
      # This approach recovers the endpoints with float32 arithmetic,
      # but can lead to rounding errors for integer outputs.
      step = reshape(lax.iota(computation_dtype, num), iota_shape) / div
      out = (reshape(broadcast_start, bounds_shape) * (1 - step) +
        reshape(broadcast_stop, bounds_shape) * step)
  elif num == 1:
    # A single sample is just `start`; the step is NaN when endpoint=True.
    delta = nan if endpoint else stop - start
    out = reshape(broadcast_start, bounds_shape)
  else: # num == 0 degenerate case, match numpy behavior
    empty_shape = list(lax.broadcast_shapes(shape(start), shape(stop)))
    empty_shape.insert(axis, 0)
    delta = nan
    out = reshape(array([], dtype=dtype), empty_shape)
  # With retstep=True the step size is returned alongside the samples.
  if retstep:
    return lax.convert_element_type(out, dtype), delta
  else:
    return lax.convert_element_type(out, dtype)
3014
3015
@_wraps(np.logspace)
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None,
             axis: int = 0):
  """Implementation of logspace differentiable in start and stop args."""
  lax._check_user_dtype_supported(dtype, "logspace")
  canonical_float = dtypes.canonicalize_dtype(float_)
  dtype = dtype or result_type(start, stop, canonical_float)
  computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)
  # Evenly spaced exponents, then base ** exponents, cast to the output dtype.
  exponents = linspace(start, stop, num,
                       endpoint=endpoint, retstep=False, dtype=None, axis=axis)
  powers = power(base, exponents)
  return lax.convert_element_type(powers, dtype)
3028
3029
@_wraps(np.geomspace)
def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis: int = 0):
  """Implementation of geomspace differentiable in start and stop args."""
  lax._check_user_dtype_supported(dtype, "geomspace")
  # Compute in at least the canonical float type; cast to `dtype` on return.
  dtype = dtype or result_type(start, stop, dtypes.canonicalize_dtype(float_))
  computation_dtype = promote_types(dtype, dtypes.canonicalize_dtype(float_))
  start = asarray(start, dtype=computation_dtype)
  stop = asarray(stop, dtype=computation_dtype)
  # follow the numpy geomspace convention for negative and complex endpoints
  # `signflip` evaluates to -1 when the real parts of both endpoints are
  # negative and to +1 when both are positive; multiplying by it before the
  # log and again afterwards lets log10 operate on the sign-adjusted values.
  signflip = 1 - (1 - sign(real(start))) * (1 - sign(real(stop))) // 2
  # The samples are generated along axis 0 and moved to `axis` afterwards.
  res = signflip * logspace(log10(signflip * start),
                            log10(signflip * stop), num,
                            endpoint=endpoint, base=10.0,
                            dtype=computation_dtype, axis=0)
  if axis != 0:
    res = moveaxis(res, 0, axis)
  return lax.convert_element_type(res, dtype)
3047
3048
@_wraps(np.meshgrid)
def meshgrid(*args, **kwargs):
  # Recognized options: indexing ('xy' or 'ij'), sparse (bool), copy (must be
  # True). They are popped so that leftovers can be rejected; previously a
  # misspelled option was silently ignored and the default used instead.
  indexing = kwargs.pop("indexing", "xy")
  sparse = kwargs.pop("sparse", False)
  copy = kwargs.pop("copy", True)
  if kwargs:
    raise TypeError("meshgrid() got unexpected keyword arguments: {}"
                    .format(sorted(kwargs)))
  if not copy:
    raise ValueError("jax.numpy.meshgrid only supports copy=True")

  args = list(args)
  if indexing == "xy":
    # 'xy' indexing swaps the roles of the first two inputs; the matching
    # outputs are swapped back at the end.
    if len(args) >= 2:
      args[0], args[1] = args[1], args[0]
  elif indexing != "ij":
    raise ValueError("Valid values for indexing are 'xy' and 'ij', got {}"
                     .format(indexing))

  # Dense grids share one full shape; sparse grids start from all-ones and
  # each output expands only its own axis.
  shape = []
  for i, a in enumerate(args):
    args[i] = a = asarray(a)
    if len(a.shape) != 1:
      msg = "Arguments to jax.numpy.meshgrid must be 1D, got shape {}"
      raise ValueError(msg.format(a.shape))
    shape.append(1 if sparse else a.shape[0])

  output = []
  for i, a in enumerate(args):
    # `a` was already converted by the validation loop above.
    s = shape
    if sparse:
      s = list(s)
      s[i] = a.shape[0]
    output.append(lax.broadcast_in_dim(a, s, (i,)))

  if indexing == "xy" and len(args) >= 2:
    output[0], output[1] = output[1], output[0]

  return output
3086
3087
@_wraps(np.i0)
def i0(x):
  # Promote to an inexact dtype and take the magnitude, then multiply
  # lax.bessel_i0e of that magnitude by its exponential.
  (promoted,) = _promote_args_inexact("i0", x)
  magnitude = lax.abs(promoted)
  return lax.mul(lax.exp(magnitude), lax.bessel_i0e(magnitude))
3092
3093
@_wraps(np.ix_)
def ix_(*args):
  count = len(args)
  grids = []
  for position, seq in enumerate(args):
    seq = asarray(seq)
    if len(seq.shape) != 1:
      msg = "Arguments to jax.numpy.ix_ must be 1-dimensional, got shape {}"
      raise ValueError(msg.format(seq.shape))
    if _dtype(seq) == bool_:
      raise NotImplementedError(
        "Boolean arguments to jax.numpy.ix_ are not implemented")
    # Output `position` varies only along axis `position`; other axes are 1.
    out_shape = [1] * position + [seq.shape[0]] + [1] * (count - position - 1)
    if seq.size == 0:
      # Numpy uses an integer index type for empty arrays.
      grid = lax.full(out_shape, np.zeros((), np.intp))
    else:
      grid = lax.broadcast_in_dim(seq, out_shape, (position,))
    grids.append(grid)
  return tuple(grids)
3114
3115
@_wraps(np.indices)
def indices(dimensions, dtype=int32, sparse=False):
  # Dimensions must be concrete Python ints.
  dimensions = tuple(
      core.concrete_or_error(int, d, "dimensions argument of jnp.indices")
      for d in dimensions)
  rank = len(dimensions)

  grids = []
  for axis, extent in enumerate(dimensions):
    if sparse:
      # Sparse grids keep only their own axis at full extent.
      grid_shape = (1,)*axis + (extent,) + (1,)*(rank - axis - 1)
    else:
      grid_shape = dimensions
    counter = lax.iota(dtype, extent)
    grids.append(lax.broadcast_in_dim(counter, grid_shape, (axis,)))

  if sparse:
    return tuple(grids)
  if grids:
    return stack(grids, 0)
  return array([], dtype=dtype)
3132
3133
# Description of the JAX-only `total_repeat_length` parameter of `repeat`;
# passed to `_wraps` as the `lax_description` for `np.repeat`.
_TOTAL_REPEAT_LENGTH_DOC = """\
Jax adds the optional `total_repeat_length` parameter which specifies the total
number of repeat, and defaults to sum(repeats). It must be specified for repeat
to be compilable. If `sum(repeats)` is larger than the specified
`total_repeat_length` the remaining values will be discarded. In the case of
`sum(repeats)` being smaller than the specified target length, the final value
will be repeated.
"""
3142
3143
@_wraps(np.repeat, lax_description=_TOTAL_REPEAT_LENGTH_DOC)
def repeat(a, repeats, axis: Optional[int] = None, *, total_repeat_length=None):
  _check_arraylike("repeat", a)

  # With no axis given, operate on the flattened input.
  if axis is None:
    a = ravel(a)
    axis = 0

  # If total_repeat_length is not given, can't compile, use a default.
  if total_repeat_length is None:
    # `repeats` must be concrete here, because the output length is derived
    # from its values.
    repeats = core.concrete_or_error(np.array, repeats,
      "When jit-compiling jnp.repeat, the total number of repeats must be static. "
      "To fix this, either specify a static value for `repeats`, or pass a static "
      "value to `total_repeat_length`.")

    # Fast path for when repeats is a scalar.
    if np.ndim(repeats) == 0 and ndim(a) != 0:
      # Insert a unit axis right after `axis`, tile it `repeats` times, then
      # fold it back into `axis` with a reshape.
      input_shape = a.shape
      aux_axis = axis if axis < 0 else axis + 1
      a = expand_dims(a, aux_axis)
      reps = [1] * len(a.shape)
      reps[aux_axis] = repeats
      a = tile(a, reps)
      result_shape = list(input_shape)
      result_shape[axis] *= repeats
      return reshape(a, result_shape)

    repeats = np.ravel(repeats)
    if ndim(a) != 0:
      repeats = np.broadcast_to(repeats, [a.shape[axis]])
    total_repeat_length = np.sum(repeats)
  else:
    # The output length is supplied by the caller, so `repeats` is handled
    # with this module's array functions instead of NumPy.
    repeats = ravel(repeats)
    if ndim(a) != 0:
      repeats = broadcast_to(repeats, [a.shape[axis]])

  # Special case when a is a scalar.
  if ndim(a) == 0:
    if repeats.shape == (1,):
      return full([total_repeat_length], a)
    else:
      raise ValueError('`repeat` with a scalar parameter `a` is only '
      'implemented for scalar values of the parameter `repeats`.')

  # Special case if total_repeat_length is zero.
  if total_repeat_length == 0:
    result_shape = list(a.shape)
    result_shape[axis] = 0
    return reshape(array([], dtype=a.dtype), result_shape)

  # If repeats is on a zero sized axis, then return the array.
  if a.shape[axis] == 0:
    return a

  # This implementation of repeat avoids having to instantiate a large
  # intermediate tensor.

  # Modify repeats from e.g. [1,2,0,5] -> [0,1,2,0] for exclusive repeat.
  exclusive_repeats = roll(repeats, shift=1).at[0].set(0)
  # Cumsum to get indices of new number in repeated tensor, e.g. [0, 1, 3, 3]
  scatter_indices = cumsum(exclusive_repeats)
  # Scatter these onto a zero buffer, e.g. [1,1,0,2,0,0,0,0]
  block_split_indicators = ops.index_add(
      x=zeros([total_repeat_length], dtype=int32),
      idx=scatter_indices,
      y=1)
  # Cumsum again to get scatter indices for repeat, e.g. [0,1,1,3,3,3,3,3]
  gather_indices = cumsum(block_split_indicators) - 1
  return take(a, gather_indices, axis=axis)
3213
3214
@_wraps(np.tri)
def tri(N, M=None, k=0, dtype=None):
  lax._check_user_dtype_supported(dtype, "tri")
  # Square by default.
  if M is None:
    M = N
  if not dtype:
    dtype = float32
  return lax._tri(dtype, (N, M), k)
3221
3222
@_wraps(np.tril)
def tril(m, k=0):
  _check_arraylike("tril", m)
  m_shape = shape(m)
  if len(m_shape) < 2:
    raise ValueError("Argument to jax.numpy.tril must be at least 2D")
  # Build a boolean mask over the trailing two axes, broadcast it across the
  # leading axes, and keep the masked entries.
  batch_dims, matrix_dims = m_shape[:-2], m_shape[-2:]
  keep = tri(*matrix_dims, k=k, dtype=bool)
  keep = lax.broadcast(keep, batch_dims)
  return lax.select(keep, m, zeros_like(m))
3231
3232
@_wraps(np.triu, update_doc=False)
def triu(m, k=0):
  _check_arraylike("triu", m)
  m_shape = shape(m)
  if len(m_shape) < 2:
    raise ValueError("Argument to jax.numpy.triu must be at least 2D")
  # Mask built with offset k - 1 marks the entries to zero out; everything
  # outside the mask is kept.
  batch_dims, matrix_dims = m_shape[:-2], m_shape[-2:]
  drop = tri(*matrix_dims, k=k - 1, dtype=bool)
  drop = lax.broadcast(drop, batch_dims)
  return lax.select(drop, zeros_like(m), m)
3241
3242
@_wraps(np.trace)
def trace(a, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, out=None):
  _check_arraylike("trace", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.trace is not supported.")
  lax._check_user_dtype_supported(dtype, "trace")

  axis1 = _canonicalize_axis(axis1, ndim(a))
  axis2 = _canonicalize_axis(axis2, ndim(a))
  dims = shape(a)

  if dtype is None:
    dtype = _dtype(a)
    # Integer inputs narrower than the default int accumulate in the
    # default int type.
    if issubdtype(dtype, integer):
      default_int = dtypes.canonicalize_dtype(np.int_)
      if iinfo(dtype).bits < iinfo(default_int).bits:
        dtype = default_int

  # Move the two traced axes to the end, keeping the others in order.
  kept_axes = [i for i in range(len(dims)) if i not in (axis1, axis2)]
  a = lax.transpose(a, kept_axes + [axis1, axis2])

  # Zero everything off the requested diagonal, then sum the last two axes.
  on_diagonal = eye(dims[axis1], dims[axis2], k=offset, dtype=bool)
  a = where(on_diagonal, a, zeros_like(a))
  return sum(a, axis=(-2, -1), dtype=dtype)
3270
3271
def _wrap_indices_function(f):
  """Wrap `f` so each element of its result is passed through `asarray`.

  The wrapper returns the converted elements as a tuple.
  """
  @_wraps(f, update_doc=False)
  def wrapper(*args, **kwargs):
    host_indices = f(*args, **kwargs)
    return tuple(asarray(component) for component in host_indices)
  return wrapper
3277
# Index helpers computed by NumPy; `_wrap_indices_function` converts each
# returned index array with `asarray` and returns them as a tuple.
tril_indices = _wrap_indices_function(np.tril_indices)
triu_indices = _wrap_indices_function(np.triu_indices)
mask_indices = _wrap_indices_function(np.mask_indices)
3281
3282
@_wraps(np.triu_indices_from)
def triu_indices_from(arr, k=0):
  # Only the trailing two dimensions of `arr` are consulted.
  num_rows = arr.shape[-2]
  num_cols = arr.shape[-1]
  return triu_indices(num_rows, k=k, m=num_cols)
3286
3287
@_wraps(np.tril_indices_from)
def tril_indices_from(arr, k=0):
  # Only the trailing two dimensions of `arr` are consulted.
  num_rows = arr.shape[-2]
  num_cols = arr.shape[-1]
  return tril_indices(num_rows, k=k, m=num_cols)
3291
3292
@_wraps(np.diag_indices)
def diag_indices(n, ndim=2):
  if n < 0:
    raise ValueError("n argument to diag_indices must be nonnegative, got {}"
                     .format(n))
  if ndim < 0:
    raise ValueError("ndim argument to diag_indices must be nonnegative, got {}"
                     .format(ndim))
  # The same 0..n-1 index vector is repeated once per dimension.
  diagonal_positions = lax.iota(int_, n)
  return (diagonal_positions,) * ndim
3302
@_wraps(np.diag_indices_from)
def diag_indices_from(arr):
  _check_arraylike("diag_indices_from", arr)
  if arr.ndim < 2:
    raise ValueError("input array must be at least 2-d")

  # Every dimension must have the same length.
  distinct_lengths = set(arr.shape)
  if len(distinct_lengths) != 1:
    raise ValueError("All dimensions of input must be of equal length")

  side = arr.shape[0]
  return diag_indices(side, ndim=arr.ndim)
3313
@_wraps(np.diagonal)
def diagonal(a, offset=0, axis1: int = 0, axis2: int = 1):
  _check_arraylike("diagonal", a)
  dims = shape(a)
  rank = len(dims)

  # Move the two dimensions to the end.
  axis1 = _canonicalize_axis(axis1, rank)
  axis2 = _canonicalize_axis(axis2, rank)
  kept_axes = [i for i in range(rank) if i not in (axis1, axis2)]
  a = lax.transpose(a, kept_axes + [axis1, axis2])

  # Zero everything off the requested diagonal, then collapse one of the two
  # trailing axes by summation (which one depends on the sign of the offset).
  on_diagonal = eye(dims[axis1], dims[axis2], k=offset, dtype=bool)
  a = where(on_diagonal, a, zeros_like(a))
  if offset < 0:
    reduce_axis = -2
  else:
    reduce_axis = -1
  collapsed = sum(a, axis=reduce_axis, dtype=_dtype(a))

  # Slice out the correct diagonal size.
  rows_available = dims[axis1] + _min(offset, 0)
  cols_available = dims[axis2] - _max(offset, 0)
  diag_size = _max(0, _min(rows_available, cols_available))
  return lax.slice_in_dim(collapsed, 0, diag_size, axis=-1)
3337
3338
@_wraps(np.diag)
def diag(v, k=0):
  _check_arraylike("diag", v)
  rank = len(shape(v))
  if rank == 2:
    # 2-D input: extract the k-th diagonal.
    return diagonal(v, offset=k)
  if rank != 1:
    raise ValueError("diag input must be 1d or 2d")

  # 1-D input: build a square matrix of side len(v) + |k|. Zero-pad the vector
  # to that length, then keep it only where the k-th diagonal mask is set.
  side = shape(v)[0] + _abs(k)
  pad_value = lax.full_like(v, shape=(), fill_value=0)
  v = lax.pad(v, pad_value, ((_max(0, k), _max(0, -k), 0),))
  return where(eye(side, k=k, dtype=bool), v, zeros_like(v))
3352
# Note on the scalar-input difference from NumPy; passed to `_wraps` as the
# `lax_description` for `np.diagflat`.
_SCALAR_VALUE_DOC="""\
This differs from np.diagflat for some scalar values of v,
jax always returns a two-dimensional array, whereas numpy may
return a scalar depending on the type of v.
"""
3358
@_wraps(np.diagflat, lax_description=_SCALAR_VALUE_DOC)
def diagflat(v, k=0):
  _check_arraylike("diagflat", v)
  v = ravel(v)
  # The square output has side len(v) + |k|.
  side = len(v) + _abs(k)
  flat = zeros(side*side, dtype=v.dtype)

  # Compute the flat (row-major) position of every diagonal entry, write the
  # values into a flat buffer and reshape it to a square at the end.
  i = arange(0, side-_abs(k))
  flat_index = i+k+i*side if k >= 0 else i+(i-k)*side
  flat = ops.index_update(flat, ops.index[flat_index], v)
  return flat.reshape(side,side)
3374
3375
@_wraps(np.polyval)
def polyval(p, x):
  if isinstance(p, np.poly1d):
    p = np.asarray(p)
  # Start from a plain 0 for a poly1d `x`, otherwise from zeros shaped like x.
  y = 0 if isinstance(x, np.poly1d) else zeros_like(x)
  # Horner's scheme over the coefficients in the order given.
  num_coeffs = len(p)
  for position in range(num_coeffs):
    y = y * x + p[position]
  return y
3387
@_wraps(np.polyadd)
def polyadd(a1, a2):
  a1 = asarray(a1)
  a2 = asarray(a2)

  # Add the shorter coefficient vector onto the tail of the longer one; on a
  # tie `a1` is the one updated.
  if a2.shape[0] <= a1.shape[0]:
    longer, shorter = a1, a2
  else:
    longer, shorter = a2, a1
  return longer.at[-shorter.shape[0]:].add(shorter)
3397
3398
@_wraps(np.polyder)
def polyder(p, m=1):
  p = asarray(p)
  if m < 0:
    raise ValueError("Order of derivative must be positive")
  if m == 0:
    return p
  if m % 1:
    raise ValueError("m must be an integer")
  # For each surviving coefficient, multiply together the m successive factors
  # (power, power - 1, ...), one per differentiation.
  leading_powers = arange(len(p), m, -1) - 1
  step_offsets = arange(m)[:, newaxis]
  coeff = (leading_powers - step_offsets).prod(0)
  return p[:-m] * coeff
3410
@_wraps(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
  # The output length depends on the data, so `filt` must be concrete.
  filt = core.concrete_or_error(asarray, filt,
    "Error arose in the `filt` argument of trim_zeros()")
  is_zero = asarray(filt) == 0
  if all(is_zero):
    return empty(0, _dtype(filt))

  # argmin over the boolean mask gives the position of the first non-zero
  # entry, scanning from the front or (on the reversed mask) from the back.
  sides = trim.lower()
  leading = argmin(is_zero) if 'f' in sides else 0
  trailing = argmin(is_zero[::-1]) if 'b' in sides else 0
  return filt[leading:len(filt) - trailing]
3421
# Description of the JAX-only `trim_leading_zeros` option; passed to `_wraps`
# as the `lax_description` for `np.polymul`.
_LEADING_ZEROS_DOC="""\
Setting trim_leading_zeros=True makes the output match that of numpy.
But prevents the function from being able to be used in compiled code.
"""
3426
@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1, a2, *, trim_leading_zeros=False):
  # poly1d inputs are converted to coefficient arrays.
  if isinstance(a1, np.poly1d):
    a1 = asarray(a1)
  if isinstance(a2, np.poly1d):
    a2 = asarray(a2)

  has_long_operand = len(a1) > 1 or len(a2) > 1
  if trim_leading_zeros and has_long_operand:
    a1 = trim_zeros(a1, trim='f')
    a2 = trim_zeros(a2, trim='f')

  # An empty operand is replaced by the zero polynomial.
  if len(a1) == 0:
    a1 = asarray([0.])
  if len(a2) == 0:
    a2 = asarray([0.])
  # The product is the full convolution of the two coefficient sequences.
  return convolve(a1, a2, mode='full')
3441
@_wraps(np.polysub)
def polysub(a1, a2):
  # Subtraction is implemented as addition of the negated second polynomial.
  minuend = asarray(a1)
  negated = -asarray(a2)
  return polyadd(minuend, negated)
3445
3446
@_wraps(np.append)
def append(arr, values, axis: Optional[int] = None):
  # With an explicit axis the inputs are joined as-is; without one both are
  # flattened first.
  if axis is not None:
    return concatenate([arr, values], axis=axis)
  flat_parts = [ravel(arr), ravel(values)]
  return concatenate(flat_parts, 0)
3453
3454
@_wraps(np.apply_along_axis)
def apply_along_axis(func1d, axis: int, arr, *args, **kwargs):
  # Apply `func1d` to 1D slices of `arr` along `axis` by wrapping it in one
  # vmap per remaining dimension.
  rank = ndim(arr)
  axis = _canonicalize_axis(axis, rank)

  def mapped(slice_1d):
    return func1d(slice_1d, *args, **kwargs)

  # First batch over the dimensions that follow `axis` ...
  for in_axis in range(1, rank - axis):
    mapped = jax.vmap(mapped, in_axes=in_axis, out_axes=-1)
  # ... then over the dimensions that precede it.
  for _ in range(axis):
    mapped = jax.vmap(mapped, in_axes=0, out_axes=0)
  return mapped(arr)
3465
3466
@_wraps(np.apply_over_axes)
def apply_over_axes(func, a, axes):
  # Apply `func` successively over each axis, keeping the rank of the result
  # equal to the rank of the input.
  result = a
  for ax in axes:
    reduced = func(result, axis=ax)
    rank_drop = result.ndim - reduced.ndim
    if rank_drop not in (0, 1):
      raise ValueError("function is not returning an array of the correct shape")
    # A result that lost exactly one dimension gets it re-inserted at `ax`.
    result = reduced if rank_drop == 0 else expand_dims(reduced, ax)
  return result
3478
3479
3480### Tensor contraction operations
3481
3482
@_wraps(np.dot, lax_description=_PRECISION_DOC)
def dot(a, b, *, precision=None):  # pylint: disable=missing-docstring
  _check_arraylike("dot", a, b)
  a, b = _promote_dtypes(a, b)
  lhs_rank = ndim(a)
  rhs_rank = ndim(b)
  # A scalar operand turns the product into an elementwise multiply.
  if lhs_rank == 0 or rhs_rank == 0:
    return lax.mul(a, b)
  # Vector/matrix combinations map directly onto lax.dot.
  if _max(lhs_rank, rhs_rank) <= 2:
    return lax.dot(a, b, precision=precision)

  # General case: contract the last axis of `a` with the second-to-last axis
  # of `b` (or its only axis when `b` is 1D); there are no batch dimensions.
  rhs_contract = 0 if rhs_rank == 1 else rhs_rank - 2
  dimension_numbers = (((lhs_rank - 1,), (rhs_contract,)), ((), ()))
  return lax.dot_general(a, b, dimension_numbers, precision)
3499
3500
@_wraps(np.matmul, lax_description=_PRECISION_DOC)
def matmul(a, b, *, precision=None):  # pylint: disable=missing-docstring
  # Matrix product with broadcasting over leading (batch) dimensions,
  # implemented as squeeze -> lax.dot_general -> transpose.
  #
  # Raises ValueError if either operand is 0-dimensional or if the batch
  # dimensions of the operands cannot be broadcast against each other.
  _check_arraylike("matmul", a, b)
  for i, x in enumerate((a, b)):
    if ndim(x) < 1:
      msg = (f"matmul input operand {i} must have ndim at least 1, "
             f"but it has ndim {ndim(x)}")
      raise ValueError(msg)

  a, b = _promote_dtypes(a, b)

  # Batch dimensions are everything before the trailing two (matrix) axes; a
  # 1D operand has none. The shorter batch shape is left-padded with None so
  # both tuples line up position by position.
  a_is_mat, b_is_mat = (ndim(a) > 1), (ndim(b) > 1)
  a_batch_dims = shape(a)[:-2] if a_is_mat else ()
  b_batch_dims = shape(b)[:-2] if b_is_mat else ()
  num_batch_dims = _max(len(a_batch_dims), len(b_batch_dims))
  a_batch_dims = (None,) * (num_batch_dims - len(a_batch_dims)) + a_batch_dims
  b_batch_dims = (None,) * (num_batch_dims - len(b_batch_dims)) + b_batch_dims

  # Dimensions to squeeze from the inputs.
  a_squeeze = []
  b_squeeze = []

  # Positions of batch dimensions in squeezed inputs.
  a_batch = []
  b_batch = []

  # Desired index in final output of each kind of dimension, in the order that
  # lax.dot_general will emit them.
  idx_batch = []
  idx_a_other = []  # other = non-batch, non-contracting.
  idx_b_other = []
  for i, (ba, bb) in enumerate(zip(a_batch_dims, b_batch_dims)):
    if ba is None:
      # Dimension exists only in `b`.
      idx_b_other.append(i)
    elif bb is None:
      # Dimension exists only in `a`.
      idx_a_other.append(i)
    elif ba == 1:
      # Size-1 dimension of `a` broadcasts against `b`: squeeze it out of `a`
      # and treat the dimension as belonging to `b` alone.
      idx_b_other.append(i)
      a_squeeze.append(len(idx_batch) + len(idx_a_other) + len(a_squeeze))
    elif bb == 1:
      # Mirror image of the case above.
      idx_a_other.append(i)
      b_squeeze.append(len(idx_batch) + len(idx_b_other) + len(b_squeeze))
    elif ba == bb:
      # Genuine shared batch dimension.
      a_batch.append(len(idx_batch) + len(idx_a_other))
      b_batch.append(len(idx_batch) + len(idx_b_other))
      idx_batch.append(i)
    else:
      raise ValueError("Incompatible shapes for matmul arguments: {} and {}"
                       .format(shape(a), shape(b)))

  # The non-contracting matrix axes come last in the output.
  if a_is_mat: idx_a_other.append(num_batch_dims)
  if b_is_mat: idx_b_other.append(num_batch_dims + a_is_mat)
  # Permutation that moves dot_general's (batch, a_other, b_other) output
  # ordering to the desired output ordering.
  perm = np.argsort(np.concatenate([idx_batch, idx_a_other, idx_b_other]))

  a = lax.squeeze(a, tuple(a_squeeze))
  b = lax.squeeze(b, tuple(b_squeeze))
  # Contract the last axis of `a` with the second-to-last axis of `b` (or the
  # only axis of `b` when it is 1D).
  out = lax.dot_general(
    a, b, (((ndim(a) - 1,), (ndim(b) - 1 - b_is_mat,)), (a_batch, b_batch)),
    precision=precision)
  return lax.transpose(out, perm)
3561
3562
@_wraps(np.vdot, lax_description=_PRECISION_DOC)
def vdot(a, b, *, precision=None):
  # Dot product of the flattened inputs, conjugating `a` when it is complex.
  _check_arraylike("vdot", a, b)
  if issubdtype(_dtype(a), complexfloating):
    a = conj(a)
  # Flatten with the module-level ravel() rather than the .ravel() method so
  # that inputs lacking that method (e.g. Python scalars) are handled too.
  return dot(ravel(a), ravel(b), precision=precision)
3569
3570
@_wraps(np.tensordot, lax_description=_PRECISION_DOC)
def tensordot(a, b, axes=2, *, precision=None):
  # Contract `a` and `b` over the axes described by `axes`, which is either an
  # int (last `axes` dims of `a` with first `axes` dims of `b`), a pair of
  # ints, or a pair of equal-length int sequences.
  _check_arraylike("tensordot", a, b)
  a_ndim = ndim(a)
  b_ndim = ndim(b)

  a, b = _promote_dtypes(a, b)
  sequence_types = (list, tuple)
  if type(axes) is int:
    if axes > _min(a_ndim, b_ndim):
      msg = "Number of tensordot axes (axes {}) exceeds input ranks ({} and {})"
      raise TypeError(msg.format(axes, a.shape, b.shape))
    lhs_axes = tuple(range(a_ndim - axes, a_ndim))
    rhs_axes = tuple(range(axes))
  elif type(axes) not in sequence_types or len(axes) != 2:
    msg = ("tensordot axes argument must be an int, a pair of ints, or a pair "
           "of lists/tuples of ints.")
    raise TypeError(msg)
  else:
    ax1, ax2 = axes
    if type(ax1) is int and type(ax2) is int:
      lhs_axes = (_canonicalize_axis(ax1, a_ndim),)
      rhs_axes = (_canonicalize_axis(ax2, b_ndim),)
    elif type(ax1) in sequence_types and type(ax2) in sequence_types:
      if len(ax1) != len(ax2):
        msg = "tensordot requires axes lists to have equal length, got {} and {}."
        raise TypeError(msg.format(ax1, ax2))
      lhs_axes = tuple(_canonicalize_axis(i, a_ndim) for i in ax1)
      rhs_axes = tuple(_canonicalize_axis(i, b_ndim) for i in ax2)
    else:
      msg = ("tensordot requires both axes lists to be either ints, tuples or "
             "lists, got {} and {}")
      raise TypeError(msg.format(ax1, ax2))
  # No batch dimensions: every non-contracted axis appears in the output.
  dimension_numbers = ((lhs_axes, rhs_axes), ((), ()))
  return lax.dot_general(a, b, dimension_numbers, precision=precision)
3604
3605
@_wraps(np.einsum, lax_description=_PRECISION_DOC)
def einsum(*operands, out=None, optimize='greedy', precision=None,
           _use_xeinsum=False):
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.")

  # Dispatch to lax.xeinsum when asked to explicitly, or when the subscript
  # string contains '{' and exactly two array operands follow it.
  use_xeinsum = _use_xeinsum
  if not use_xeinsum:
    first = operands[0]
    use_xeinsum = (isinstance(first, str) and '{' in first and
                   len(operands) == 3)
  if use_xeinsum:
    return lax.xeinsum(*operands)

  if optimize is True:
    optimize = 'greedy'
  # using einsum_call=True here is an internal api for opt_einsum
  operands, path = opt_einsum.contract_path(
      *operands, einsum_call=True, use_blas=True, optimize=optimize)
  # Keep only the first three fields of each contraction and make the set of
  # contracted names hashable, since `contractions` is a static jit argument
  # of _einsum.
  contractions = tuple((indices, frozenset(removed), einstr)
                       for indices, removed, einstr, *_ in path)
  return _einsum(operands, contractions, precision)
3622
@_wraps(np.einsum_path)
def einsum_path(subscripts, *operands, optimize='greedy'):
  # Path planning is delegated entirely to opt_einsum's public contract_path.
  path_info = opt_einsum.contract_path(subscripts, *operands,
                                       optimize=optimize)
  return path_info
3627
def _removechars(s, chars):
  """Returns `s` with every character listed in `chars` deleted.

  `chars` may be a string or any iterable of single-character strings.
  """
  deletions = {ch: None for ch in chars}
  table = str.maketrans(deletions)
  return s.translate(table)
3630
@partial(jit, static_argnums=(1, 2))
def _einsum(operands: Sequence,
            contractions: Sequence[Tuple[Tuple[int, ...], FrozenSet[str], str]],
            precision):
  """Executes a pre-planned sequence of einsum contractions.

  Args:
    operands: the array operands; they are promoted to a common dtype.
    contractions: one entry per step, each a tuple of (positions of the
      operands to pop for this step, set of axis names contracted in this
      step, einsum string "inputs->result" for this step). Static under jit.
    precision: forwarded to lax.dot_general. Static under jit.

  Returns:
    The single operand left after all contraction steps have run.
  """
  operands = list(_promote_dtypes(*operands))
  # NOTE: this local helper shadows the builtin `sum` inside _einsum. It
  # reduces with add, or with bitwise_or for boolean operands.
  def sum(x, axes):
    return lax.reduce(x, np.array(0, x.dtype),
                      lax.add if x.dtype != bool_ else lax.bitwise_or, axes)

  # Sums out axes whose name occurs exactly once and is being contracted.
  def sum_uniques(operand, names, uniques):
    if uniques:
      axes = [names.index(name) for name in uniques]
      operand = sum(operand, axes)
      names = _removechars(names, uniques)
    return operand, names

  # Handles names that occur more than once within one operand: the operand is
  # multiplied by lax._delta over the repeated axes (presumably a mask that is
  # nonzero only where those indices coincide) and then summed over all of
  # them, or over all but the last when the name must be kept.
  def sum_repeats(operand, names, counts, keep_names):
    for name, count in counts.items():
      if count > 1:
        axes = [i for i, n in enumerate(names) if n == name]
        eye = lax._delta(operand.dtype, operand.shape, axes)
        if name not in keep_names:
          operand = sum(operand * eye, axes)
          names = names.replace(name, '')
        else:
          operand = sum(operand * eye, axes[:-1])
          names = names.replace(name, '', count - 1)
    return operand, names

  # Drops size-1 axes of `operand` whose name also appears in the other
  # operand with a size different from 1.
  def filter_singleton_dims(operand, names, other_shape, other_names):
    s = shape(operand)
    new_shape = []
    new_names = []
    for i, d in enumerate(names):
      other_i = other_names.find(d)
      if s[i] != 1 or other_i == -1 or other_shape[other_i] == 1:
        new_shape.append(s[i])
        new_names.append(d)
    return reshape(operand, tuple(new_shape)), "".join(new_names)

  for operand_indices, contracted_names_set, einstr in contractions:
    # Sort so the iteration order over contracted names is deterministic.
    contracted_names = sorted(contracted_names_set)
    input_str, result_names = einstr.split('->')
    input_names = input_str.split(',')

    # switch on the number of operands to be processed in this loop iteration.
    # every case here sets 'operand' and 'names'.
    if len(operand_indices) == 1:
      operand = operands.pop(operand_indices[0])
      names, = input_names
      counts = collections.Counter(names)

      # sum out unique contracted indices with a single reduce-sum
      uniques = [name for name in contracted_names if counts[name] == 1]
      operand, names = sum_uniques(operand, names, uniques)

      # for every repeated index, do a contraction against an identity matrix
      operand, names = sum_repeats(operand, names, counts, result_names)

    elif len(operand_indices) == 2:
      lhs, rhs = map(operands.pop, operand_indices)
      lhs_names, rhs_names = input_names

      # handle cases where one side of a contracting or batch dimension is 1
      # but its counterpart is not.
      lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
                                             rhs_names)
      rhs, rhs_names = filter_singleton_dims(rhs, rhs_names, shape(lhs),
                                             lhs_names)

      lhs_counts = collections.Counter(lhs_names)
      rhs_counts = collections.Counter(rhs_names)

      # sum out unique contracted indices in lhs and rhs
      lhs_uniques = [name for name in contracted_names
                     if lhs_counts[name] == 1 and rhs_counts[name] == 0]
      lhs, lhs_names = sum_uniques(lhs, lhs_names, lhs_uniques)

      rhs_uniques = [name for name in contracted_names
                     if rhs_counts[name] == 1 and lhs_counts[name] == 0]
      rhs, rhs_names = sum_uniques(rhs, rhs_names, rhs_uniques)

      # for every repeated index, contract against an identity matrix
      lhs, lhs_names = sum_repeats(lhs, lhs_names, lhs_counts,
                                   result_names + rhs_names)
      rhs, rhs_names = sum_repeats(rhs, rhs_names, rhs_counts,
                                   result_names + lhs_names)

      # Names already summed out above no longer need contracting; names
      # shared by both operands that survive into the result are batch dims.
      lhs_or_rhs_names = set(lhs_names) | set(rhs_names)
      contracted_names = [x for x in contracted_names if x in lhs_or_rhs_names]
      lhs_and_rhs_names = set(lhs_names) & set(rhs_names)
      batch_names = [x for x in result_names if x in lhs_and_rhs_names]

      lhs_batch, rhs_batch = unzip2((lhs_names.find(n), rhs_names.find(n))
                                    for n in batch_names)

      # NOTE(mattjj): this can fail non-deterministically in python3, maybe
      # due to opt_einsum
      assert _all(
        name in lhs_names and name in rhs_names and
        lhs.shape[lhs_names.index(name)] == rhs.shape[rhs_names.index(name)]
        for name in contracted_names)

      # contract using lax.dot_general
      batch_names_str = ''.join(batch_names)
      lhs_cont, rhs_cont = unzip2((lhs_names.index(n), rhs_names.index(n))
                                  for n in contracted_names)
      dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
      operand = lax.dot_general(lhs, rhs, dimension_numbers, precision)
      # Axis labels of the dot_general output: batch names first, then the
      # remaining lhs names, then the remaining rhs names.
      deleted_names = batch_names_str + ''.join(contracted_names)
      names = (batch_names_str + _removechars(lhs_names, deleted_names)
               + _removechars(rhs_names, deleted_names))
    else:
      raise NotImplementedError  # if this is actually reachable, open an issue!

    # the resulting 'operand' with axis labels 'names' should be a permutation
    # of the desired result
    assert len(names) == len(result_names) == len(set(names))
    assert set(names) == set(result_names)
    if names != result_names:
      perm = tuple([names.index(name) for name in result_names])
      operand = lax.transpose(operand, perm)
    operands.append(operand)  # used in next iteration

  return operands[0]
3756
3757
def _movechars(s, src, dst):
  """Helper for einsum string munging, like moveaxis on identifier strings."""
  # Start from the characters that stay put, in their original order ...
  remaining = []
  for position, char in enumerate(s):
    if position not in src:
      remaining.append(char)
  # ... then re-insert the moved characters, lowest destination first.
  for target, origin in sorted(zip(dst, src)):
    remaining.insert(target, s[origin])
  return ''.join(remaining)
3764
3765
@_wraps(np.inner, lax_description=_PRECISION_DOC)
def inner(a, b, *, precision=None):
  # A scalar operand has no last axis to contract, so fall back to a plain
  # product; otherwise contract the last axes of both operands.
  either_scalar = ndim(a) == 0 or ndim(b) == 0
  if either_scalar:
    return a * b
  last_axes = (-1, -1)
  return tensordot(a, b, last_axes, precision=precision)
3771
3772
@_wraps(np.outer)
def outer(a, b, out=None):
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.outer is not supported.")
  a, b = _promote_dtypes(a, b)
  # Broadcast a flattened column against a flattened row.
  column = ravel(a)[:, None]
  row = ravel(b)[None, :]
  return column * row
3779
@partial(jit, static_argnums=(2, 3, 4))
def _cross(a, b, axisa, axisb, axisc):
  """Jitted implementation of cross(); the three axis arguments are static."""
  # Bring the vector components to the last axis of each operand.
  a = moveaxis(a, axisa, -1)
  b = moveaxis(b, axisb, -1)
  a_len = a.shape[-1]
  b_len = b.shape[-1]

  allowed = (2, 3)
  if a_len not in allowed or b_len not in allowed:
    raise ValueError("Dimension must be either 2 or 3 for cross product")

  # Two 2-vectors: only the z component of the product is returned.
  if a_len == 2 and b_len == 2:
    return a[..., 0] * b[..., 1] - a[..., 1] * b[..., 0]

  # Otherwise a missing third component is treated as zero.
  a0 = a[..., 0]
  a1 = a[..., 1]
  if a_len == 3:
    a2 = a[..., 2]
  else:
    a2 = zeros_like(a0)
  b0 = b[..., 0]
  b1 = b[..., 1]
  if b_len == 3:
    b2 = b[..., 2]
  else:
    b2 = zeros_like(b0)
  components = [a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0]
  return moveaxis(array(components), 0, axisc)
3799
@_wraps(np.cross)
def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1,
          axis: Optional[int] = None):
  # A non-None `axis` overrides all three per-operand axis arguments.
  if axis is not None:
    axisa = axisb = axisc = axis
  return _cross(a, b, axisa, axisb, axisc)
3808
@_wraps(np.kron)
def kron(a, b):
  a, b = _promote_dtypes(a, b)
  # Left-pad the lower-rank operand with size-1 dimensions so ranks match.
  rank_gap = ndim(b) - ndim(a)
  if rank_gap > 0:
    a = reshape(a, (1,) * rank_gap + shape(a))
  elif rank_gap < 0:
    b = reshape(b, (1,) * (-rank_gap) + shape(b))
  a_dims = shape(a)
  b_dims = shape(b)
  # Interleave size-1 axes so that every axis of `a` broadcasts against the
  # matching axis of `b`: a -> (d0, 1, d1, 1, ...), b -> (1, e0, 1, e1, ...).
  a_spread = []
  b_spread = []
  for a_dim, b_dim in zip(a_dims, b_dims):
    a_spread += [a_dim, 1]
    b_spread += [1, b_dim]
  result_shape = tuple(np.multiply(a_dims, b_dims))
  product = lax.mul(reshape(a, a_spread), reshape(b, b_spread))
  return reshape(product, result_shape)
3820
3821
@_wraps(np.vander)
def vander(x, N=None, increasing=False):
  # Build the Vandermonde matrix of the 1D input `x` with `N` columns.
  #
  # Raises ValueError if `x` is not one-dimensional or if `N` is negative.
  x = asarray(x)
  dtype = _dtype(x)
  if ndim(x) != 1:
    raise ValueError("x must be a one-dimensional array")
  # Only fall back to len(x) when N was omitted. A truthiness test
  # (`N or ...`) would also replace an explicit N=0, which NumPy treats as a
  # request for a matrix with zero columns.
  if N is None:
    N = shape(x)[0]
  if N < 0:
    raise ValueError("N must be nonnegative")

  # Exponents 0..N-1, reversed unless `increasing` is set.
  iota = lax.iota(dtype, N)
  if not increasing:
    iota = lax.sub(lax._const(iota, N - 1), iota)

  return power(x[..., None], iota)
3838
3839
3840### Misc
3841
3842
@_wraps(np.argwhere)
def argwhere(a):
  # Stack the per-dimension index arrays from nonzero() into rows of
  # coordinates, one row per nonzero element.
  coords = transpose(vstack(nonzero(a)))
  num_rows = coords.shape[0]
  rank = ndim(a)
  if rank == 0:
    # A scalar input has no coordinates: keep the leading size, zero columns.
    return coords[:0].reshape(num_rows, 0)
  return coords.reshape(num_rows, rank)
3849
3850
@_wraps(np.argmax)
def argmax(a, axis: Optional[int] = None, out=None):
  _check_arraylike("argmax", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.")
  # axis=None means "search the flattened array".
  if axis is None:
    operand, dim = ravel(a), 0
  else:
    operand, dim = a, axis
  if operand.shape[dim] == 0:
    raise ValueError("attempt to get argmax of an empty sequence")
  return lax.argmax(operand, _canonicalize_axis(dim, operand.ndim), int64)
3862
@_wraps(np.argmin)
def argmin(a, axis: Optional[int] = None, out=None):
  _check_arraylike("argmin", a)
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
  # axis=None means "search the flattened array".
  if axis is None:
    operand, dim = ravel(a), 0
  else:
    operand, dim = a, axis
  if operand.shape[dim] == 0:
    raise ValueError("attempt to get argmin of an empty sequence")
  return lax.argmin(operand, _canonicalize_axis(dim, operand.ndim), int64)
3874
3875
# Extra docstring text for nanargmax/nanargmin. The placeholder is filled with
# "max" or "min", so the template must spell out the "nanarg" prefix for the
# warning to name the function it is attached to.
_NANARG_DOC = """\
Warning: jax.numpy.nanarg{} returns -1 for all-NaN slices and does not raise
an error.
"""
3880
@_wraps(np.nanargmax, lax_description=_NANARG_DOC.format("max"))
def nanargmax(a, axis: Optional[int] = None):
  _check_arraylike("nanargmax", a)
  if issubdtype(_dtype(a), inexact):
    nan_mask = isnan(a)
    # NaNs are replaced by -inf so they never beat a real value.
    best = argmax(where(nan_mask, -inf, a), axis=axis)
    # Slices made up entirely of NaNs report -1.
    return where(all(nan_mask, axis=axis), -1, best)
  # Non-inexact dtypes cannot hold NaN, so plain argmax suffices.
  return argmax(a, axis=axis)
3890
@_wraps(np.nanargmin, lax_description=_NANARG_DOC.format("min"))
def nanargmin(a, axis: Optional[int] = None):
  _check_arraylike("nanargmin", a)
  if issubdtype(_dtype(a), inexact):
    nan_mask = isnan(a)
    # NaNs are replaced by +inf so they never beat a real value.
    best = argmin(where(nan_mask, inf, a), axis=axis)
    # Slices made up entirely of NaNs report -1.
    return where(all(nan_mask, axis=axis), -1, best)
  # Non-inexact dtypes cannot hold NaN, so plain argmin suffices.
  return argmin(a, axis=axis)
3900
3901
@_wraps(np.sort)
def sort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
  # Sort `a` along `axis`, or sort the flattened array when `axis` is None.
  # `kind` is accepted for NumPy compatibility but ignored (with a warning if
  # it is not the default); `order` is not supported and raises ValueError.
  _check_arraylike("sort", a)
  if kind != 'quicksort':
    warnings.warn("'kind' argument to sort is ignored.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")

  if axis is None:
    # Flatten with the module-level ravel() rather than the .ravel() method so
    # that inputs lacking that method (e.g. Python scalars) are handled too.
    return lax.sort(ravel(a), dimension=0)
  else:
    return lax.sort(a, dimension=_canonicalize_axis(axis, ndim(a)))
3914
@_wraps(np.sort_complex)
def sort_complex(a):
  # Sort `a` and return the result with a complex dtype.
  _check_arraylike("sort_complex", a)
  # NumPy's sort_complex sorts along the last axis (it calls ndarray.sort()
  # with its default axis), so do the same here rather than always sorting
  # along axis 0. For 1D inputs the two are identical.
  a = lax.sort(a, dimension=_canonicalize_axis(-1, ndim(a)))
  return lax.convert_element_type(a, result_type(a, dtypes.canonicalize_dtype(complex_)))
3920
@_wraps(np.lexsort)
def lexsort(keys, axis=-1):
  keys = tuple(keys)
  if not keys:
    raise TypeError("need sequence of keys with len > 0 in lexsort")
  distinct_shapes = {shape(key) for key in keys}
  if len(distinct_shapes) > 1:
    raise ValueError("all keys need to be the same shape")
  first_key = keys[0]
  if ndim(first_key) == 0:
    return np.int64(0)
  axis = _canonicalize_axis(axis, ndim(first_key))
  # Sort an index array alongside the keys (passed in reverse order) and
  # return the permuted indices.
  positions = lax.broadcasted_iota(np.int64, shape(first_key), axis)
  sorted_operands = lax.sort((*keys[::-1], positions), dimension=axis,
                             num_keys=len(keys))
  return sorted_operands[-1]
3933
3934
@_wraps(np.argsort)
def argsort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
  # Return the indices that sort `a` along `axis` (or the flattened array when
  # `axis` is None). `kind` is ignored (with a warning if it is not the
  # default); `order` is not supported and raises ValueError.
  _check_arraylike("argsort", a)
  if kind != 'quicksort':
    warnings.warn("'kind' argument to argsort is ignored.")
  if order is not None:
    raise ValueError("'order' argument to argsort is not supported.")

  if axis is None:
    # Flatten with the module-level ravel() rather than the .ravel() method so
    # that inputs lacking that method (e.g. Python scalars) are handled too.
    return argsort(ravel(a), 0)
  else:
    axis_num = _canonicalize_axis(axis, ndim(a))
    # Sort an index array alongside the values and return the permuted indices.
    iota = lax.broadcasted_iota(np.int64, shape(a), axis_num)
    _, perm = lax.sort_key_val(a, iota, dimension=axis_num)
    return perm
3950
3951
@_wraps(np.msort)
def msort(a):
  # msort is simply a sort along the leading axis.
  leading_axis = 0
  return sort(a, axis=leading_axis)
3955
3956
@partial(jit, static_argnums=(2,))
def _roll(a, shift, axis):
  """Jitted implementation of roll(); `axis` is a static argument."""
  a = asarray(a)
  orig_shape = shape(a)
  if axis is None:
    # Roll the flattened array, then restore the original shape.
    rolled_flat = roll(ravel(a), shift, axis=0)
    return lax.reshape(rolled_flat, orig_shape)

  rank = len(orig_shape)
  shift = asarray(shift)
  axis = np.asarray(axis)
  # `shift` and `axis` are broadcast against each other into (shift, axis)
  # pairs; anything that does not broadcast to 1D is rejected.
  pair_shape = lax.broadcast_shapes(shift.shape, axis.shape, (1,))
  if len(pair_shape) != 1:
    msg = "'shift' and 'axis' arguments to roll must be scalars or 1D arrays"
    raise ValueError(msg)

  shifts = broadcast_to(shift, pair_shape)
  axes = np.broadcast_to(axis, pair_shape)
  for amount, ax in zip(shifts, axes):
    ax = _canonicalize_axis(ax, rank)
    size = orig_shape[ax]
    # `size or 1` avoids a modulo by zero for empty axes.
    amount = remainder(amount, (size or 1))
    # Concatenate the array with itself and slice out the rolled window.
    doubled = lax.concatenate((a, a), ax)
    a = lax.dynamic_slice_in_dim(doubled, size - amount, size, axis=ax)
  return a
3979
3980
@_wraps(np.roll)
def roll(a, shift, axis: Optional[Union[int, Sequence[int]]] = None):
  # `axis` is a static jit argument of _roll; a list is converted to a tuple
  # first (presumably because static arguments must be hashable).
  static_axis = tuple(axis) if isinstance(axis, list) else axis
  return _roll(a, shift, static_axis)
3986
3987
@_wraps(np.rollaxis)
def rollaxis(a, axis: int, start=0):
  _check_arraylike("rollaxis", a)
  a_ndim = ndim(a)
  axis = _canonicalize_axis(axis, a_ndim)
  if start < -a_ndim or start > a_ndim:
    raise ValueError(f"start={start} must satisfy {-a_ndim}<=start<={a_ndim}")
  # Translate rollaxis' "insert before position `start`" convention into the
  # destination index expected by moveaxis.
  destination = start + a_ndim if start < 0 else start
  if destination > axis:
    destination -= 1
  return moveaxis(a, axis, destination)
4000
4001
@_wraps(np.packbits)
def packbits(a, axis: Optional[int] = None, bitorder='big'):
  a = asarray(a)
  if not issubdtype(dtype(a), integer) and not issubdtype(dtype(a), bool_):
    raise TypeError('Expected an input array of integer or boolean data type')
  if bitorder not in ('little', 'big'):
    raise ValueError("'order' must be either 'little' or 'big'")
  # Every positive element counts as a set bit.
  a = (a > 0).astype('uint8')
  # Shift amount applied to each of the 8 bits in a group.
  shifts = arange(8, dtype='uint8')
  if bitorder == 'big':
    shifts = shifts[::-1]
  if axis is None:
    a = ravel(a)
    axis = 0
  # Work on the last axis; the swap is undone on return.
  a = swapaxes(a, axis, -1)

  # Zero-pad the packing axis up to a multiple of 8.
  leftover = a.shape[-1] % 8
  if leftover:
    padding = (a.ndim - 1) * [(0, 0)] + [(0, 8 - leftover)]
    a = pad(a, padding)

  grouped = a.reshape(a.shape[:-1] + (a.shape[-1] // 8, 8))
  packed = (grouped << shifts).sum(-1).astype('uint8')
  return swapaxes(packed, axis, -1)
4025
4026
@_wraps(np.unpackbits)
def unpackbits(a, axis: Optional[int] = None, count=None, bitorder='big'):
  a = asarray(a)
  if dtype(a) != uint8:
    raise TypeError("Expected an input array of unsigned byte data type")
  if bitorder not in ('little', 'big'):
    raise ValueError("'order' must be either 'little' or 'big'")
  # One single-bit mask per bit position.
  masks = asarray(1) << arange(8, dtype='uint8')
  if bitorder == 'big':
    masks = masks[::-1]
  if axis is None:
    a = a.ravel()
    axis = 0
  # Work on the last axis; the swap is undone on return.
  a = swapaxes(a, axis, -1)
  # Test each byte against the 8 masks, then merge the new trailing axis of
  # bits into the last axis and keep at most `count` entries.
  bit_planes = ((a[..., None] & masks) > 0).astype('uint8')
  flat_bits = bit_planes.reshape(bit_planes.shape[:-2] + (-1,))
  unpacked = flat_bits[..., :count]
  return swapaxes(unpacked, axis, -1)
4044
4045
@_wraps(np.take)
def take(a, indices, axis: Optional[int] = None, out=None, mode=None):
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.take is not supported.")

  a = asarray(a)
  indices = asarray(indices)

  # axis=None means "index into the flattened array".
  if axis is None:
    a = ravel(a)
    axis_idx = 0
  else:
    axis_idx = _canonicalize_axis(axis, ndim(a))

  if mode == "raise":
    # TODO(phawkins): we have no way to report out of bounds errors yet.
    raise NotImplementedError("The 'raise' mode to jnp.take is not supported.")
  if mode == "wrap":
    indices = mod(indices, _constant_like(indices, a.shape[axis_idx]))
  elif not (mode == "clip" or mode is None):
    raise ValueError("Invalid mode '{}' for np.take".format(mode))

  # Each gathered slice spans `a` fully except along `axis_idx`, where it has
  # size 1 (or 0 when there are no indices at all).
  index_rank = len(shape(indices))
  slice_sizes = list(shape(a))
  slice_sizes[axis_idx] = _min(indices.size, 1)
  # In the output, the dimensions of `indices` replace `axis_idx`; the
  # remaining dimensions of `a` are the offset dimensions around them.
  leading = range(axis_idx)
  trailing = range(axis_idx + index_rank, len(a.shape) + index_rank - 1)
  dnums = lax.GatherDimensionNumbers(
    offset_dims=(*leading, *trailing),
    collapsed_slice_dims=(axis_idx,),
    start_index_map=(axis_idx,))
  return lax.gather(a, indices[..., None], dimension_numbers=dnums,
                    slice_sizes=tuple(slice_sizes))
4079
4080
def _normalize_index(index, axis_size):
  """Normalizes an index value in the range [-N, N) to the range [0, N)."""
  # Polymorphic (Poly) axis sizes are handled with plain Python arithmetic.
  if type(axis_size) is Poly:
    if index < 0:
      return index + axis_size
    return index

  is_negative = lax.lt(index, _constant_like(index, 0))
  wrapped = lax.add(index, _constant_like(index, axis_size))
  return lax.select(is_negative, wrapped, index)
4090
@partial(jit, static_argnums=(2,))
def _take_along_axis(arr, indices, axis):
  """Jitted implementation of take_along_axis(); `axis` is a static argument.

  Builds a single lax.gather whose index tensor combines the user-supplied
  `indices` along `axis` with iota coordinates for every other dimension of
  `indices` that is not of size 1.

  Raises:
    ValueError: if `axis` is None and `indices` is not 1D, or if `arr` and
      `indices` have a different number of dimensions.
  """
  if axis is None:
    if ndim(indices) != 1:
      msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
      raise ValueError(msg.format(indices.shape))
    return take_along_axis(arr.ravel(), indices, 0)
  rank = ndim(arr)
  if rank != ndim(indices):
    msg = "indices and arr must have the same number of dimensions; {} vs. {}"
    raise ValueError(msg.format(ndim(indices), ndim(arr)))
  axis = _canonicalize_axis(axis, rank)

  # Returns a copy of shape tuple `tup` with the entry at `axis` set to `val`.
  def replace(tup, val):
    lst = list(tup)
    lst[axis] = val
    return tuple(lst)

  # 64-bit indices are used only when some dimension is a Poly or does not
  # fit in a signed 32-bit integer.
  use_64bit_index = _any([type(d) is Poly or d >= (1 << 31) for d in arr.shape])
  index_dtype = int64 if use_64bit_index else int32
  indices = lax.convert_element_type(indices, index_dtype)

  # Broadcast `arr` and `indices` against each other on every dimension
  # except `axis`, where each keeps its own size.
  bcast_shape = lax.broadcast_shapes(replace(arr.shape, 1), replace(indices.shape, 1))
  indices = broadcast_to(indices, replace(bcast_shape, indices.shape[axis]))
  arr     = broadcast_to(arr,     replace(bcast_shape, arr.shape[axis]))

  axis_size = arr.shape[axis]
  arr_shape = replace(arr.shape, 1)
  idx_shape = indices.shape
  out_shape = lax.broadcast_shapes(idx_shape, arr_shape)

  # Dimensions that get an explicit coordinate in the gather index tensor.
  index_dims = [i for i, idx in enumerate(idx_shape) if i == axis or idx != 1]

  gather_index_shape = tuple(np.array(out_shape)[index_dims]) + (1,)
  gather_indices = []
  slice_sizes = []
  offset_dims = []
  start_index_map = []
  collapsed_slice_dims = []
  # `j` counts the indexed dimensions emitted so far, i.e. the position of the
  # next one within gather_index_shape.
  j = 0
  for i in range(rank):
    if i == axis:
      # The user-provided indices, with negative values wrapped around.
      indices = _normalize_index(indices, axis_size)
      gather_indices.append(lax.reshape(indices, gather_index_shape))
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1
    elif idx_shape[i] != 1:
      # Every other indexed dimension is addressed by its own position.
      iota = lax.iota(_dtype(indices), out_shape[i])
      if not config.omnistaging_enabled:
        iota = lax.tie_in(arr, iota)
      iota = lax.broadcast_in_dim(iota, gather_index_shape, (j,))
      gather_indices.append(iota)
      slice_sizes.append(1)
      start_index_map.append(i)
      collapsed_slice_dims.append(i)
      j += 1
    else:
      # If idx_shape[i] == 1, we can just take the entirety of the arr's axis
      # and avoid forming an iota index.
      offset_dims.append(i)
      slice_sizes.append(arr_shape[i])

  # Stack the per-dimension coordinates along the trailing size-1 axis.
  gather_indices = lax.concatenate(gather_indices, dimension=j)
  dnums = lax.GatherDimensionNumbers(
    offset_dims=tuple(offset_dims),
    collapsed_slice_dims=tuple(collapsed_slice_dims),
    start_index_map=tuple(start_index_map))
  return lax.gather(arr, gather_indices, dnums, tuple(slice_sizes))
4161
4162
@_wraps(getattr(np, "take_along_axis", None), update_doc=False)
def take_along_axis(arr, indices, axis: Optional[int]):
  # Validate the input here, before handing off to the jitted implementation.
  _check_arraylike("take_along_axis", arr)
  gathered = _take_along_axis(arr, indices, axis)
  return gathered
4167
4168
4169### SetOps
4170
@partial(jit, static_argnums=1)
def _unique1d_sorted_mask(ar, optional_indices=False):
  """
  Jit-able helper for unique.

  Returns the flattened input in sorted order together with a boolean mask
  that is True at the first occurrence of each distinct value. When
  `optional_indices` is set, the sorting permutation is returned as well.
  """
  flat = asarray(ar).flatten()

  if optional_indices:
    perm = flat.argsort()
    aux = flat[perm]
  else:
    aux = flat.sort()

  # The first element always starts a new run; every later element starts one
  # when it differs from its predecessor.
  starts_run = empty(aux.shape, dtype=bool_)
  starts_run = ops.index_update(starts_run, ops.index[:1], True)
  starts_run = ops.index_update(starts_run, ops.index[1:],
                                aux[1:] != aux[:-1])

  if optional_indices:
    return aux, starts_run, perm
  return aux, starts_run
4193
def _unique1d(ar, return_index=False, return_inverse=False,
              return_counts=False):
  """
  Find the unique elements of an array, ignoring shape.

  Returns a tuple whose first entry holds the sorted unique values, followed
  by the optional index, inverse and count arrays in that order.
  """
  optional_indices = return_index or return_inverse

  # The sorting permutation is only produced when one of the index outputs
  # needs it.
  sorted_result = _unique1d_sorted_mask(ar, optional_indices)
  if optional_indices:
    aux, mask, perm = sorted_result
  else:
    aux, mask = sorted_result

  outputs = [aux[mask]]
  if return_index:
    outputs.append(perm[mask])
  if return_inverse:
    # Rank of each sorted element among the unique values, scattered back to
    # the original element positions.
    imask = cumsum(mask) - 1
    inv_idx = zeros(mask.shape, dtype=dtypes.canonicalize_dtype(int_))
    inv_idx = ops.index_update(inv_idx, perm, imask)
    outputs.append(inv_idx)
  if return_counts:
    # Run lengths are the gaps between consecutive run starts, with the total
    # size appended as a closing boundary.
    boundaries = concatenate(nonzero(mask) + (array([mask.size]),))
    outputs.append(diff(boundaries))
  return tuple(outputs)
4219
@_wraps(np.unique)
def unique(ar, return_index=False, return_inverse=False,
           return_counts=False, axis: Optional[int] = None):
  # The input must be concrete, so tracers are rejected here.
  ar = core.concrete_or_error(asarray, ar, "The error arose in jnp.unique()")

  if iscomplexobj(ar):
    raise NotImplementedError(
          "np.unique is not implemented for complex valued arrays")

  if axis is not None:
    raise NotImplementedError(
          "np.unique is not implemented for the axis argument")

  results = _unique1d(ar, return_index, return_inverse, return_counts)
  # A lone result is returned bare rather than as a 1-tuple.
  return results[0] if len(results) == 1 else results
4238
4239### Indexing
4240
def _rewriting_take(arr, idx):
  # Computes arr[idx].
  # All supported cases of indexing can be implemented as an XLA gather,
  # followed by an optional reverse and broadcast_in_dim.
  operand = asarray(arr)
  # The index is split into static and dynamic parts before being handed to
  # _gather.
  tree, static_part, dynamic_part = _split_index_for_jit(idx)
  return _gather(operand, tree, static_part, dynamic_part)
4248
# TODO(phawkins): re-enable jit after fixing excessive recompilation for
# slice indexes (e.g., slice(0, 5, None), slice(10, 15, None), etc.).
# @partial(jit, static_argnums=(1, 2))
def _gather(arr, treedef, static_idx, dynamic_idx):
  """Computes arr[idx] for an index split by _split_index_for_jit."""
  idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
  indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update

  # Avoid calling gather if the slice shape is empty, both as a fast path and to
  # handle cases like zeros(0)[array([], int32)].
  if _prod(indexer.slice_shape) == 0:
    return zeros_like(arr, shape=indexer.slice_shape)

  result = arr
  # We avoid generating a gather when indexer.gather_indices.size is empty.
  if indexer.gather_indices.size:
    result = lax.gather(result, indexer.gather_indices, indexer.dnums,
                        indexer.gather_slice_shape)

  # Reverses axes with negative strides.
  if indexer.reversed_y_dims:
    result = lax.rev(result, indexer.reversed_y_dims)

  # This adds np.newaxis/None dimensions.
  return expand_dims(result, indexer.newaxis_dims)
4273
# Description of how an indexing expression maps onto a lax.gather call.
_Indexer = collections.namedtuple("_Indexer", (
  # Shape the slice output is expected to have.
  "slice_shape",
  # Slice shape handed to lax.gather().
  "gather_slice_shape",
  # Indices handed to lax.gather().
  "gather_indices",
  # GatherDimensionNumbers object describing the gather to perform.
  "dnums",
  # Slice dimensions with negative strides; these must be reversed after the
  # gather.
  "reversed_y_dims",
  # Axes created by `newaxis`. These must be inserted for gathers and
  # eliminated for scatters.
  "newaxis_dims",
))
4295
def _split_index_for_jit(idx):
  """Separates an index into its necessarily-static and its dynamic leaves.

  Intended for passing indices into a `jit`-ted function. Returns a triple
  ``(treedef, static, dynamic)`` where ``static`` is a tuple and ``dynamic`` a
  list, both with one entry per flattened leaf; the entry that does not apply
  to a given leaf is ``None``.
  """
  # Non-tuple sequences (deprecated by NumPy) are rejected; everything else
  # comes back as a tuple.
  idx = _eliminate_deprecated_list_indexing(idx)

  # Concrete boolean indices are expanded so that the advanced integer
  # indexing logic can handle them.
  idx = _expand_bool_indices(idx)

  leaves, treedef = tree_flatten(idx)
  static, dynamic = [], []
  for leaf in leaves:
    if leaf is Ellipsis:
      static.append(leaf)
      dynamic.append(None)
    elif isinstance(leaf, slice):
      # slice objects aren't hashable, so store their fields as a tuple.
      static.append((leaf.start, leaf.stop, leaf.step))
      dynamic.append(None)
    else:
      static.append(None)
      dynamic.append(leaf)
  return treedef, tuple(static), dynamic
4320
def _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx):
  """Rebuilds the index that _split_index_for_jit took apart."""

  def restore(static_leaf, dynamic_leaf):
    # A dynamic entry always wins; otherwise a tuple encodes a slice and
    # anything else is passed through untouched.
    if dynamic_leaf is not None:
      return dynamic_leaf
    if isinstance(static_leaf, tuple):
      return slice(static_leaf[0], static_leaf[1], static_leaf[2])
    return static_leaf

  leaves = [restore(s, d) for s, d in zip(static_idx, dynamic_idx)]
  return treedef.unflatten(leaves)
4332
def _int(aval):
  """Tells whether `aval` describes a scalar (empty shape) of integer dtype."""
  is_scalar = not aval.shape
  return is_scalar and issubdtype(aval.dtype, integer)
4335
def _index_to_gather(x_shape, idx, normalize_indices=True):
  """Translates a NumPy-style index into the ingredients of a gather.

  Args:
    x_shape: shape of the array being indexed.
    idx: tuple of index components (integers, slices, None, Ellipsis, integer
      arrays or sequences).
    normalize_indices: if True, integer indices are passed through
      `_normalize_index` together with the size of the axis they index
      (presumably to wrap negative indices -- see that helper).

  Returns:
    An `_Indexer` holding the gather indices, the gather dimension numbers and
    slice shape, the expected output shape, the output axes to reverse and the
    positions of any `None`/newaxis dimensions.

  Raises:
    IndexError: for a basic integer index into an axis of size 0, for a slice
      whose start/stop/step are not static, or for an unsupported index kind.
    TypeError: for an indexer whose dtype is neither integer nor boolean.
  """
  # Remove ellipses and add trailing slice(None)s.
  idx = _canonicalize_tuple_index(len(x_shape), idx)

  # Check for advanced indexing:
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

  # Do the advanced indexing axes appear contiguously? If not, NumPy semantics
  # move the advanced axes to the front.
  advanced_axes_are_contiguous = False

  advanced_indexes = None

  # The positions of the advanced indexing axes in `idx`.
  idx_advanced_axes = []

  # The positions of the advanced indexes in x's shape.
  # collapsed, after None axes have been removed. See below.
  x_advanced_axes = None

  if _is_advanced_int_indexer(idx):
    # `i` is the position in `idx`; once None entries are dropped, the
    # enumeration index `j` below is the corresponding axis of x.
    idx_no_nones = [(i, d) for i, d in enumerate(idx) if d is not None]
    advanced_pairs = (
      (asarray(e), i, j) for j, (i, e) in enumerate(idx_no_nones)
      if isscalar(e) or isinstance(e, (Sequence, ndarray)))
    if normalize_indices:
      advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
                        for e, i, j in advanced_pairs)
    advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
    advanced_axes_are_contiguous = np.all(np.diff(idx_advanced_axes) == 1)

  x_axis = 0  # Current axis in x.
  y_axis = 0  # Current axis in y, before collapsing. See below.
  collapsed_y_axis = 0  # Current axis in y, after collapsing.

  # Dimension numbers, accumulated while walking over `idx`.
  offset_dims = []
  collapsed_slice_dims = []
  start_index_map = []

  # Use 64-bit indices only if some dimension is polymorphic or too large for
  # a signed 32-bit index.
  use_64bit_index = _any([type(d) is Poly or d >= (1 << 31) for d in x_shape])
  index_dtype = int64 if use_64bit_index else int32
  gather_indices = np.zeros((0,), dtype=index_dtype)  # use np to save a compilation

  # We perform three transformations to y before the scatter op, in order:
  # First, y is broadcast to slice_shape. In general `y` only need broadcast to
  # the right shape.
  slice_shape = []

  # Next, y is squeezed to remove newaxis_dims. This removes np.newaxis/`None`
  # indices, which the scatter cannot remove itself.
  newaxis_dims = []

  # Finally, we reverse reversed_y_dims to handle slices with negative strides.
  reversed_y_dims = []

  gather_slice_shape = []

  for idx_pos, i in enumerate(idx):
    # Handle the advanced indices here if:
    # * the advanced indices were not contiguous and we are the start.
    # * we are at the position of the first advanced index.
    if (advanced_indexes is not None and
        (advanced_axes_are_contiguous and idx_pos == idx_advanced_axes[0] or
         not advanced_axes_are_contiguous and idx_pos == 0)):
      advanced_indexes = broadcast_arrays(*advanced_indexes)
      shape = advanced_indexes[0].shape
      ndim = len(shape)
      # Give every advanced index a trailing axis of size 1 so they can be
      # concatenated along the last dimension of gather_indices.
      advanced_indexes = [
        lax.convert_element_type(lax.reshape(a, shape + (1,)), index_dtype)
        for a in advanced_indexes]

      # Broadcast gather_indices from [..., k] to [..., 1, 1, ..., 1, k].
      gather_indices = lax.broadcast_in_dim(
        gather_indices, np.insert(gather_indices.shape, -1, shape),
        tuple(range(gather_indices.ndim - 1)) + (gather_indices.ndim + ndim - 1,))
      gather_indices = concatenate([gather_indices] + advanced_indexes, -1)
      start_index_map.extend(x_advanced_axes)
      collapsed_slice_dims.extend(x_advanced_axes)
      slice_shape.extend(shape)
      y_axis += ndim
      collapsed_y_axis += ndim

    # Per-index bookkeeping for advanced indexes.
    if idx_pos in idx_advanced_axes:
      x_axis += 1
      gather_slice_shape.append(1)
      continue

    # get_aval raises TypeError for objects it cannot abstract; such entries
    # fall through to the None/slice/error branches below.
    try:
      abstract_i = core.get_aval(i)
    except TypeError:
      abstract_i = None
    # Handle basic int indexes.
    if isinstance(abstract_i, (ConcreteArray,ShapedArray)) and _int(abstract_i):
      if x_shape[x_axis] == 0:
        # XLA gives error when indexing into an axis of size 0
        raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
      i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
      if type(i) is Poly:
        # dummy index if i is polynomial, doesn't matter for shape inference
        # TODO(mattjj,j-towns,juliuskunze): revise this logic
        i = 0
      i = lax.convert_element_type(i, index_dtype)
      i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,))
      gather_indices = concatenate((gather_indices, i), -1)
      collapsed_slice_dims.append(x_axis)
      gather_slice_shape.append(1)
      start_index_map.append(x_axis)
      x_axis += 1
    # Handle np.newaxis (None)
    elif i is None:
      slice_shape.append(1)
      newaxis_dims.append(y_axis)
      y_axis += 1
    # Handle slice(None)
    elif _is_slice_none(i):
      slice_shape.append(x_shape[x_axis])
      gather_slice_shape.append(x_shape[x_axis])
      offset_dims.append(collapsed_y_axis)
      collapsed_y_axis += 1
      y_axis += 1
      x_axis += 1
    # Handle slice index (only static, otherwise an error is raised)
    elif isinstance(i, slice):
      if not _all(elt is None or type(elt) is Poly
                  or type(core.get_aval(elt)) is ConcreteArray
                  for elt in (i.start, i.stop, i.step)):
        msg = ("Array slice indices must have static start/stop/step to be used "
               "with NumPy indexing syntax. To index a statically sized "
               "array at a dynamic position, try lax.dynamic_slice/"
               "dynamic_update_slice (JAX does not support dynamically sized "
               "arrays within JIT compiled functions).")
        raise IndexError(msg)
      start, limit, stride, needs_rev = _static_idx(i, x_shape[x_axis])
      if needs_rev:
        reversed_y_dims.append(collapsed_y_axis)
      if stride == 1:
        # Unit stride: a single start index plus a window of limit - start
        # elements along this axis.
        i = lax.convert_element_type(start, index_dtype)
        i = broadcast_to(i, tuple(gather_indices.shape[:-1]) + (1,))
        gather_indices = concatenate((gather_indices, i), -1)
        slice_shape.append(limit - start)
        gather_slice_shape.append(limit - start)
        offset_dims.append(collapsed_y_axis)
        start_index_map.append(x_axis)
      else:
        # Non-unit stride: enumerate the selected positions explicitly and
        # gather size-1 windows at each of them.
        i = arange(start, limit, stride, dtype=index_dtype)
        size = i.shape[0]
        slice_shape.append(size)
        gather_slice_shape.append(1)
        gather_indices_shape = tuple(gather_indices.shape[:-1]) + (size,)
        i = lax.broadcast_in_dim(
            i, shape=gather_indices_shape + (1,),
            broadcast_dimensions=(len(gather_indices_shape) - 1,))
        gather_indices = lax.broadcast_in_dim(
            gather_indices,
            shape=gather_indices_shape + (len(start_index_map),),
            broadcast_dimensions=(
              tuple(range(len(gather_indices_shape) - 1)) +
              (len(gather_indices_shape),)))
        gather_indices = concatenate(
          (gather_indices, i), len(gather_indices_shape))
        start_index_map.append(x_axis)
        collapsed_slice_dims.append(x_axis)

      collapsed_y_axis += 1
      y_axis += 1
      x_axis += 1
    else:
      # Anything else is unsupported; give a more specific message when the
      # indexer has a dtype that is neither integer nor boolean.
      if (abstract_i is not None and
          not (issubdtype(abstract_i.dtype, integer) or issubdtype(abstract_i.dtype, bool_))):
        msg = ("Indexer must have integer or boolean type, got indexer "
               "with type {} at position {}, indexer value {}")
        raise TypeError(msg.format(abstract_i.dtype.name, idx_pos, i))

      msg = "Indexing mode not yet supported. Open a feature request!\n{}"
      raise IndexError(msg.format(idx))

  dnums = lax.GatherDimensionNumbers(
    offset_dims = tuple(offset_dims),
    collapsed_slice_dims = tuple(sorted(collapsed_slice_dims)),
    start_index_map = tuple(start_index_map)
  )
  return _Indexer(
    slice_shape=slice_shape,
    newaxis_dims=tuple(newaxis_dims),
    gather_slice_shape=gather_slice_shape,
    reversed_y_dims=reversed_y_dims,
    dnums=dnums,
    gather_indices=gather_indices)
4526
def _should_unpack_list_index(x):
  """Helper for _eliminate_deprecated_list_indexing.

  True for None, Ellipsis, slices, sequences and arrays of nonzero rank.
  """
  if x is None or x is Ellipsis:
    return True
  if isinstance(x, (Sequence, slice)):
    return True
  return isinstance(x, ndarray) and np.ndim(x) != 0
4532
def _eliminate_deprecated_list_indexing(idx):
  """Returns `idx` as a tuple, rejecting non-tuple sequences with a TypeError.

  Tuples are returned unchanged and non-sequences (as well as arrays) are
  wrapped in a 1-tuple.
  """
  # "Basic slicing is initiated if the selection object is a non-array,
  # non-tuple sequence containing slice objects, [Ellipses, or newaxis
  # objects]". Detects this and raises a TypeError.
  if isinstance(idx, tuple):
    return idx
  if not isinstance(idx, Sequence) or isinstance(idx, ndarray):
    return (idx,)

  # As of numpy 1.16, some non-tuple sequences of indices result in a warning, while
  # others are converted to arrays, based on a set of somewhat convoluted heuristics
  # (See https://github.com/numpy/numpy/blob/v1.19.2/numpy/core/src/multiarray/mapping.c#L179-L343)
  # In JAX, we raise an informative TypeError for *all* non-tuple sequences.
  needs_tuple = _any(_should_unpack_list_index(i) for i in idx)
  if needs_tuple:
    msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
           "use `arr[tuple(seq)]` instead of `arr[seq]`. "
           "See https://github.com/google/jax/issues/4564 for more information.")
  else:
    msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; "
           "use `arr[array(seq)]` instead of `arr[seq]`. "
           "See https://github.com/google/jax/issues/4564 for more information.")
  raise TypeError(msg)
4555
def _expand_bool_indices(idx):
  """Replaces concrete boolean indexes by the equivalent integer indexes.

  Raises IndexError if a boolean index is not concrete.
  """
  expanded = []
  for elt in idx:
    try:
      aval = core.get_aval(elt)
    except TypeError:
      aval = None

    # Either an array of bools, or a list made only of boolean scalars.
    is_boolean = (
      isinstance(aval, ShapedArray) and issubdtype(aval.dtype, bool_)
      or isinstance(elt, list)
      and _all(not _shape(e) and issubdtype(_dtype(e), bool_) for e in elt))
    if not is_boolean:
      expanded.append(elt)
      continue

    if isinstance(elt, list):
      elt = array(elt)
      aval = core.get_aval(elt)

    if type(aval) is not ConcreteArray:
      # TODO(mattjj): improve this error by tracking _why_ the indices are not
      # concrete
      raise IndexError("Array boolean indices must be concrete.")
    # np.where on a mask yields one integer index array per mask dimension.
    expanded.extend(np.where(elt))
  return tuple(expanded)
4580
def _is_slice_none(idx):
  """Return True if idx is equal to slice(None), False otherwise.

  Always returns a bool; in particular a non-slice argument yields False
  rather than falling through and returning None.
  """
  return (isinstance(idx, slice)
          and idx.start is None and idx.stop is None and idx.step is None)
4585
# TODO(mattjj): clean up this logic
def _is_advanced_int_indexer(idx):
  """Returns True if idx should trigger int array indexing, False otherwise."""
  # https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing
  assert isinstance(idx, tuple)

  # An index made only of rank-0 entries is never advanced indexing.
  only_rank_zero = _all(np.ndim(elt) == 0 for elt in idx)
  if only_rank_zero:
    return False

  def _admissible(e):
    return (e is None or e is Ellipsis or isinstance(e, slice)
            or _is_int_arraylike(e))

  return _all(_admissible(e) for e in idx)
4595
def _is_int_arraylike(x):
  """Returns True if x is array-like with integer dtype, False otherwise."""
  # A Python int counts, but bool (a subclass of int) does not.
  if isinstance(x, int) and not isinstance(x, bool):
    return True
  # Anything that carries an integer dtype.
  if issubdtype(getattr(x, "dtype", None), np.integer):
    return True
  # Lists/tuples qualify when every element does, recursively.
  return isinstance(x, (list, tuple)) and _all(_is_int_arraylike(e) for e in x)
4601
4602
def _canonicalize_tuple_index(arr_ndim, idx):
  """Helper to remove Ellipsis and add in the implicit trailing slice(None).

  Raises IndexError if `idx` addresses more axes than `arr_ndim` or contains
  more than one Ellipsis.
  """
  # Only entries other than None and Ellipsis consume an axis of the array.
  n_consuming = _sum(1 for e in idx if e is not None and e is not Ellipsis)
  if n_consuming > arr_ndim:
    msg = "Too many indices for array: {} non-None/Ellipsis indices for dim {}."
    raise IndexError(msg.format(n_consuming, arr_ndim))

  ellipsis_positions = [pos for pos, elt in enumerate(idx) if elt is Ellipsis]
  if len(ellipsis_positions) > 1:
    msg = "Multiple ellipses (...) not supported: {}."
    raise IndexError(msg.format(list(map(type, idx))))

  # One full slice for every axis the index leaves unaddressed.
  fill = (slice(None),) * (arr_ndim - n_consuming)
  if ellipsis_positions:
    at = ellipsis_positions[0]
    return idx[:at] + fill + idx[at + 1:]
  if n_consuming < arr_ndim:
    return tuple(idx) + fill
  return idx
4621
def _polymorphic_slice_indices(idx: slice, size: Union[int, Poly]):
  """Like ``idx.indices(size)``, but tolerates polymorphic indices and size."""
  # see https://github.com/python/cpython/blob/6d6508765514c7c10719478a0430f5e47c9a96ac/Objects/sliceobject.c#L372
  assert isinstance(idx, slice)

  step = idx.step if idx.step is not None else 1
  if step < 0:
    lower = -1
    upper = size + lower
    default_start, default_stop = upper, lower
  else:
    lower = 0
    upper = size + lower
    default_start, default_stop = lower, upper

  def _bound(index, default):
    # Missing bounds take the default, Poly values pass through untouched, and
    # concrete values are clipped into [lower, upper].
    if index is None:
      return default
    if type(index) is Poly:
      return index
    if index < 0:
      return _max(index + size, lower)
    return _min(index, upper)

  start = _bound(idx.start, default_start)
  stop = _bound(idx.stop, default_stop)
  return start, stop, step
4645
def _static_idx(idx: slice, size: Union[int, Poly]):
  """Helper function to compute the static slice start/limit/stride values.

  Returns ``(start, limit, stride, needs_rev)`` with a positive stride;
  ``needs_rev`` is True when the original slice had a negative step.
  """
  involves_poly = _any(
    type(s) is Poly for s in (idx.start, idx.stop, idx.step, size))
  if involves_poly:
    start, stop, step = _polymorphic_slice_indices(idx, size)
  elif isinstance(size, int):
    start, stop, step = idx.indices(size)
  else:
    raise TypeError(size)

  bounds_are_concrete = type(start) is not Poly and type(stop) is not Poly
  if bounds_are_concrete:
    if step < 0:
      is_empty = stop >= start
    else:
      is_empty = step > 0 and start >= stop
    if is_empty:
      return 0, 0, 1, False  # sliced to size zero

  if step > 0:
    return start, stop, step, False

  # Negative step: describe the same elements as an ascending slice that the
  # caller reverses afterwards.
  k = (start - stop - 1) % (-step)
  return stop + k + 1, start + 1, -step, True
4664
4665
# Window functions. Each one is obtained by passing the NumPy implementation
# to _wrap_numpy_nullary_function (defined outside this section).
blackman = _wrap_numpy_nullary_function(np.blackman)
bartlett = _wrap_numpy_nullary_function(np.bartlett)
hamming = _wrap_numpy_nullary_function(np.hamming)
hanning = _wrap_numpy_nullary_function(np.hanning)
# TODO: lower `kaiser` via lax to allow non-constant beta values.
kaiser = _wrap_numpy_nullary_function(np.kaiser)
4672
def _gcd_cond_fn(xs):
  """while_loop condition for gcd: true while any second operand is nonzero."""
  _, second = xs
  return any(second != 0)
4676
def _gcd_body_fn(xs):
  """while_loop body for gcd: one elementwise Euclidean step.

  Where the second operand is nonzero the pair (a, b) becomes (b, a rem b);
  elsewhere it becomes (a, 0). The returned pair is ordered so that the larger
  value comes first.
  """
  a, b = xs
  b_nonzero = b != 0
  next_a = where(b_nonzero, b, a)
  next_b = where(b_nonzero, lax.rem(a, b), lax._const(b, 0))
  out_of_order = next_a < next_b
  return (where(out_of_order, next_b, next_a),
          where(out_of_order, next_a, next_b))
4682
@_wraps(getattr(np, "gcd", None))
def gcd(x1, x2):
  _check_arraylike("gcd", x1, x2)
  both_integer = (issubdtype(_dtype(x1), integer) and
                  issubdtype(_dtype(x2), integer))
  if not both_integer:
    raise ValueError("Arguments to jax.numpy.gcd must be integers.")
  x1, x2 = broadcast_arrays(*_promote_dtypes(x1, x2))
  # Iterate the Euclidean step on the absolute values until every second
  # operand has reached zero; the first component is then the result.
  init = (abs(x1), abs(x2))
  result, _ = lax.while_loop(_gcd_cond_fn, _gcd_body_fn, init)
  return result
4693
4694
@_wraps(getattr(np, "lcm", None))
def lcm(x1, x2):
  _check_arraylike("lcm", x1, x2)
  x1, x2 = _promote_dtypes(x1, x2)
  divisor = gcd(x1, x2)
  # |x1 * (x2 // gcd)|, with 0 wherever the gcd itself is 0.
  zero = lax._const(divisor, 0)
  magnitude = abs(multiply(x1, floor_divide(x2, divisor)))
  return where(divisor == 0, zero, magnitude)
4702
4703
@_wraps(np.extract)
def extract(condition, arr):
  # Both operands are flattened, then the selection is delegated to compress.
  flat_condition = ravel(condition)
  flat_values = ravel(arr)
  return compress(flat_condition, flat_values)
4707
4708
@_wraps(np.compress)
def compress(condition, a, axis: Optional[int] = None, out=None):
  if out is not None:
    raise NotImplementedError("The 'out' argument to jnp.compress is not supported.")
  if ndim(condition) != 1:
    raise ValueError("condition must be a 1D array")
  mask = asarray(condition).astype(bool)
  a = asarray(a)

  # Bring the axis being filtered to the front (flattening when axis is None).
  if axis is None:
    axis = 0
    a = ravel(a)
  else:
    a = moveaxis(a, axis, 0)

  # Mask entries beyond the axis length must all be False.
  n = a.shape[0]
  mask, overflow = mask[:n], mask[n:]
  if any(overflow):
    raise ValueError("condition contains entries that are out of bounds")

  # A mask shorter than the axis only addresses the leading entries.
  a = a[:mask.shape[0]]
  selected = a[mask]
  return moveaxis(selected, 0, axis)
4727
4728
@_wraps(np.cov)
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
        aweights=None):
  if y is not None:
    raise NotImplementedError(
      "jax.numpy.cov not implemented for nontrivial y. "
      "Open a feature request at https://github.com/google/jax/issues !")

  (m,) = _promote_args_inexact("cov", m)
  if m.ndim > 2:
    raise ValueError("m has more than 2 dimensions")  # same as numpy error

  # Lay the data out with one variable per row and one observation per column.
  X = atleast_2d(m)
  if not rowvar and X.shape[0] != 1:
    X = X.T
  if X.shape[0] == 0:
    return array([]).reshape(0, 0)

  if ddof is None:
    if bias == 0:
      ddof = 1
    else:
      ddof = 0

  n_obs = X.shape[1]
  weights = None
  if fweights is not None:
    _check_arraylike("cov", fweights)
    if ndim(fweights) > 1:
      raise RuntimeError("cannot handle multidimensional fweights")
    if shape(fweights)[0] != n_obs:
      raise RuntimeError("incompatible numbers of samples and fweights")
    if not issubdtype(_dtype(fweights), integer):
      raise TypeError("fweights must be integer.")
    # Ensure positive fweights; note that numpy raises an error on negative fweights.
    weights = asarray(abs(fweights))
  if aweights is not None:
    _check_arraylike("cov", aweights)
    if ndim(aweights) > 1:
      raise RuntimeError("cannot handle multidimensional aweights")
    if shape(aweights)[0] != n_obs:
      raise RuntimeError("incompatible numbers of samples and aweights")
    # Ensure positive aweights: note that numpy raises an error for negative aweights.
    aweights = abs(aweights)
    if weights is None:
      weights = aweights
    else:
      weights = weights * aweights

  avg, w_sum = average(X, axis=1, weights=weights, returned=True)
  w_sum = w_sum[0]

  # Normalization factor.
  if weights is None:
    norm = n_obs - ddof
  elif ddof == 0:
    norm = w_sum
  elif aweights is None:
    norm = w_sum - ddof
  else:
    norm = w_sum - ddof * sum(weights * aweights) / w_sum

  centered = X - avg[:, None]
  if weights is None:
    right = centered.T
  else:
    right = (centered * weights).T
  return true_divide(dot(centered, right.conj()), norm).squeeze()
4784
4785
@_wraps(np.corrcoef)
def corrcoef(x, y=None, rowvar=True):
  _check_arraylike("corrcoef", x)
  c = cov(x, y, rowvar)
  if not shape(c):
    # scalar - this should yield nan for values (nan/nan, inf/inf, 0/0), 1 otherwise
    return divide(c, c)

  # Scale rows and columns by the standard deviations from the diagonal.
  stddev = sqrt(real(diag(c)))
  c = divide(c, stddev[:,None])
  c = divide(c, stddev[None,:])

  # Clip real and imaginary parts separately to [-1, 1].
  real_part = clip(real(c), -1, 1)
  if not iscomplexobj(c):
    return real_part
  complex_part = clip(imag(c), -1, 1)
  return lax.complex(real_part, complex_part)
4805
4806
@_wraps(getattr(np, "quantile", None))
def quantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
             overwrite_input=False, interpolation="linear", keepdims=False):
  _check_arraylike("quantile", a, q)
  unsupported = overwrite_input or out is not None
  if unsupported:
    raise ValueError("jax.numpy.quantile does not support overwrite_input=True or "
                     "out != None")
  # The trailing False is squash_nans: NaNs get no special treatment here.
  return _quantile(a, q, axis, interpolation, keepdims, False)
4816
@_wraps(getattr(np, "nanquantile", None))
def nanquantile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                out=None, overwrite_input=False, interpolation="linear",
                keepdims=False):
  _check_arraylike("nanquantile", a, q)
  unsupported = overwrite_input or out is not None
  if unsupported:
    raise ValueError("jax.numpy.nanquantile does not support overwrite_input=True or "
                     "out != None")
  # The trailing True is squash_nans: NaN entries are left out of the counts.
  return _quantile(a, q, axis, interpolation, keepdims, True)
4827
4828
@partial(jit, static_argnums=(2, 3, 4, 5))
def _quantile(a, q, axis, interpolation, keepdims, squash_nans):
  """Shared implementation of quantile and nanquantile.

  Args:
    a: input data; converted to an inexact dtype (at least float32).
    q: scalar or rank-1 array of quantiles; converted like `a`.
    axis: None (flatten `a`) or a single integer axis. Tuples are not
      implemented.
    interpolation: one of "linear", "lower", "higher", "midpoint", "nearest".
    keepdims: whether the reduced axis is kept with size 1.
    squash_nans: if True, positions are computed from the per-slice count of
      non-NaN entries instead of the full axis length.

  Returns:
    The requested quantiles, converted to the dtype of the promoted `a`.

  Raises:
    ValueError: for an unknown `interpolation` or a `q` of rank > 1.
    NotImplementedError: for a tuple-valued `axis`.
  """
  if interpolation not in ["linear", "lower", "higher", "midpoint", "nearest"]:
    raise ValueError("interpolation can only be 'linear', 'lower', 'higher', "
                     "'midpoint', or 'nearest'")
  a = asarray(a, dtype=promote_types(_dtype(a), float32))
  q = asarray(q, dtype=promote_types(_dtype(q), float32))
  if axis is None:
    a = ravel(a)
    axis = 0
  elif isinstance(axis, tuple):
    raise NotImplementedError("Tuple values for axis are not implemented")
  else:
    axis = _canonicalize_axis(axis, ndim(a))

  q_shape = shape(q)
  q_ndim = ndim(q)
  if q_ndim > 1:
    raise ValueError("q must have rank <= 1, got shape {}".format(shape(q)))

  a_shape = shape(a)
  a = lax.sort(a, dimension=axis)

  if squash_nans:
    # Number of non-NaN entries along `axis`, in q's dtype.
    counts = sum(logical_not(isnan(a)), axis=axis, dtype=q.dtype,
                 keepdims=keepdims)
    shape_after_reduction = counts.shape
    # Give q trailing axes and counts leading axes so that they broadcast to
    # q_shape + shape_after_reduction.
    q = lax.expand_dims(
      q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim)))
    counts = lax.expand_dims(counts, tuple(range(q_ndim)))
    # Fractional position of each quantile among the counted entries.
    q = lax.mul(q, lax.sub(counts, _constant_like(q, 1)))
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_constant_like(high_weight, 1), high_weight)

    # Clip both neighbours into [0, counts - 1] before using them as indices.
    low = lax.max(_constant_like(low, 0), lax.min(low, counts - 1))
    high = lax.max(_constant_like(high, 0), lax.min(high, counts - 1))
    low = lax.convert_element_type(low, int64)
    high = lax.convert_element_type(high, int64)
    out_shape = q_shape + shape_after_reduction
    # Identity (iota) indices for every remaining axis; the reduced axis is
    # indexed by low/high instead.
    index = [lax.broadcasted_iota(int64, out_shape, dim + q_ndim)
             for dim in range(len(shape_after_reduction))]
    if keepdims:
      index[axis] = low
    else:
      index.insert(axis, low)
    low_value = a[tuple(index)]
    index[axis] = high
    high_value = a[tuple(index)]
  else:
    n = a_shape[axis]
    # Fractional position of each quantile along the sorted axis.
    q = lax.mul(q, _constant_like(q, n - 1))
    low = lax.floor(q)
    high = lax.ceil(q)
    high_weight = lax.sub(q, low)
    low_weight = lax.sub(_constant_like(high_weight, 1), high_weight)

    low = lax.clamp(_constant_like(low, 0), low, _constant_like(low, n - 1))
    high = lax.clamp(_constant_like(high, 0), high, _constant_like(high, n - 1))
    low = lax.convert_element_type(low, int64)
    high = lax.convert_element_type(high, int64)

    # Gather size-1 windows along `axis` at the low/high positions; the axis
    # is collapsed unless keepdims is set.
    slice_sizes = list(a_shape)
    slice_sizes[axis] = 1
    dnums = lax.GatherDimensionNumbers(
      offset_dims=tuple(range(
        q_ndim,
        len(a_shape) + q_ndim if keepdims else len(a_shape) + q_ndim - 1)),
      collapsed_slice_dims=() if keepdims else (axis,),
      start_index_map=(axis,))
    low_value = lax.gather(a, low[..., None], dimension_numbers=dnums,
                           slice_sizes=slice_sizes)
    high_value = lax.gather(a, high[..., None], dimension_numbers=dnums,
                            slice_sizes=slice_sizes)
    if q_ndim == 1:
      # Spread the per-quantile weights over the gathered values' shape.
      low_weight = lax.broadcast_in_dim(low_weight, low_value.shape,
                                        broadcast_dimensions=(0,))
      high_weight = lax.broadcast_in_dim(high_weight, high_value.shape,
                                        broadcast_dimensions=(0,))

  if interpolation == "linear":
    result = lax.add(lax.mul(low_value.astype(q.dtype), low_weight),
                     lax.mul(high_value.astype(q.dtype), high_weight))
  elif interpolation == "lower":
    result = low_value
  elif interpolation == "higher":
    result = high_value
  elif interpolation == "nearest":
    pred = lax.le(high_weight, _constant_like(high_weight, 0.5))
    result = lax.select(pred, low_value, high_value)
  elif interpolation == "midpoint":
    result = lax.mul(lax.add(low_value, high_value), _constant_like(low_value, 0.5))
  else:
    # Defensive only: interpolation was already validated at the top.
    raise ValueError(f"interpolation={interpolation!r} not recognized")

  return lax.convert_element_type(result, a.dtype)
4926
4927
@partial(jit, static_argnums=2)
@partial(vectorize, excluded={0, 2})
def _searchsorted(a, v, side):
  # Binary search for v in a, run for a fixed number of halving steps.
  n = len(a)
  if n == 0:
    return 0
  if side == 'left':
    goes_left = operator.le
  else:
    goes_left = operator.lt

  def bisect_step(_, bounds):
    lo, hi = bounds
    mid = (lo + hi) // 2
    left = goes_left(v, a[mid])
    return (where(left, lo, mid), where(left, mid, hi))

  n_levels = int(np.ceil(np.log2(n + 1)))
  _, upper = lax.fori_loop(0, n_levels, bisect_step, (0, n))
  return upper
4943
4944
@_wraps(np.searchsorted)
def searchsorted(a, v, side='left', sorter=None):
  # Validate the arguments up front, then delegate to the jitted helper.
  if side not in ('left', 'right'):
    raise ValueError(f"{side!r} is an invalid value for keyword 'side'")
  if sorter is not None:
    raise NotImplementedError("sorter is not implemented")
  haystack = asarray(a)
  needles = asarray(v)
  if ndim(haystack) != 1:
    raise ValueError("a should be 1-dimensional")
  return _searchsorted(haystack, needles, side)
4956
4957
@_wraps(np.digitize)
def digitize(x, bins, right=False):
  if len(bins) == 0:
    # With no bin edges every value falls into bin 0. The result must have the
    # shape of `x`, so use zeros_like; zeros(x, ...) would treat the data in
    # `x` as a shape.
    return zeros_like(x, dtype=dtypes.canonicalize_dtype(int_))
  # np.digitize's `right` flag is the opposite of searchsorted's `side`.
  side = 'right' if not right else 'left'
  # Bins are taken as increasing when the last edge is >= the first; otherwise
  # search the reversed bins and mirror the resulting index.
  return where(
    bins[-1] >= bins[0],
    searchsorted(bins, x, side=side),
    len(bins) - searchsorted(bins[::-1], x, side=side)
  )
4968
# Extra text passed as `lax_description` to _wraps for jax.numpy.piecewise.
_PIECEWISE_DOC = """\
Unlike `np.piecewise`, :py:func:`jax.numpy.piecewise` requires functions in
`funclist` to be traceable by JAX, as it is implemented via :func:`jax.lax.switch`.
See the :func:`jax.lax.switch` documentation for more information.
"""
4974
@_wraps(np.piecewise, lax_description=_PIECEWISE_DOC)
def piecewise(x, condlist, funclist, *args, **kw):
  _check_arraylike("piecewise", x)
  condlist = array(condlist, dtype=bool_)
  nc, nf = len(condlist), len(funclist)
  # Slot 0 holds the fallback: the extra trailing function if one was given,
  # the constant 0 otherwise.
  if nf == nc + 1:
    funclist = funclist[-1:] + funclist[:-1]
  elif nf == nc:
    funclist = [0] + list(funclist)
  else:
    raise ValueError(f"with {nc} condition(s), either {nc} or {nc+1} functions are expected; got {nf}")

  # Per-element branch selector: stack an all-False row on top of the
  # conditions and take the argmax of the running sum along the first axis.
  padded = vstack([zeros_like(condlist[:1]), condlist])
  indices = argmax(cumsum(padded, 0), 0)
  dtype = _dtype(x)

  def _as_branch(f):
    if callable(f):
      def branch(x):
        return f(x, *args, **kw).astype(dtype)
    else:
      def branch(x):
        return full_like(x, f)
    return branch

  branches = [_as_branch(f) for f in funclist]
  return vectorize(lax.switch, excluded=(1,))(indices, branches, x)
4994
4995
@_wraps(np.percentile)
def percentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
               out=None, overwrite_input=False, interpolation="linear",
               keepdims=False):
  _check_arraylike("percentile", a)
  # Percentiles are quantiles expressed on a 0..100 scale.
  fraction = true_divide(asarray(q), float32(100.0))
  return quantile(a, fraction, axis=axis, out=out,
                  overwrite_input=overwrite_input,
                  interpolation=interpolation, keepdims=keepdims)
5004
@_wraps(np.nanpercentile)
def nanpercentile(a, q, axis: Optional[Union[int, Tuple[int, ...]]] = None,
                  out=None, overwrite_input=False, interpolation="linear",
                  keepdims=False):
  _check_arraylike("nanpercentile", a)
  # Percentiles are quantiles expressed on a 0..100 scale.
  fraction = true_divide(asarray(q), float32(100.0))
  return nanquantile(a, fraction, axis=axis, out=out,
                     overwrite_input=overwrite_input,
                     interpolation=interpolation, keepdims=keepdims)
5013
@_wraps(np.median)
def median(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
           overwrite_input=False, keepdims=False):
  _check_arraylike("median", a)
  # The median is the 0.5 quantile with 'midpoint' interpolation.
  return quantile(a, 0.5, axis=axis, out=out,
                  overwrite_input=overwrite_input,
                  interpolation='midpoint', keepdims=keepdims)
5020
@_wraps(np.nanmedian)
def nanmedian(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None,
              overwrite_input=False, keepdims=False):
  _check_arraylike("nanmedian", a)
  # The median is the 0.5 quantile with 'midpoint' interpolation.
  return nanquantile(a, 0.5, axis=axis, out=out,
                     overwrite_input=overwrite_input,
                     interpolation='midpoint', keepdims=keepdims)
5028
5029
def _astype(arr, dtype):
  """Converts `arr` to `dtype`, after running lax's user-dtype check on it."""
  lax._check_user_dtype_supported(dtype, "astype")
  converted = lax.convert_element_type(arr, dtype)
  return converted
5033
5034
def _nbytes(arr):
  """Returns the element count of `arr` times the item size of its dtype."""
  n_elements = size(arr)
  return n_elements * _dtype(arr).itemsize
5037
5038
def _view(arr, dtype=None, type=None):
  """Reinterprets the bits of `arr` as `dtype`.

  Args:
    arr: the array to reinterpret.
    dtype: target dtype. If None or equal to the dtype of `arr`, `arr` is
      returned unchanged.
    type: not supported; must be None.

  Returns:
    An array of dtype `dtype`. When the item sizes differ, the last axis grows
    or shrinks by the ratio of the two sizes.

  Raises:
    NotImplementedError: if `type` is given, or if either item size is not
      8, 16, 32 or 64 bits.
    ValueError: when widening and the bits of the last axis do not split
      evenly into elements of the target dtype.
  """
  lax._check_user_dtype_supported(dtype, "view")
  if type is not None:
    raise NotImplementedError("`type` argument of array.view()")
  if dtype is None:
    return arr
  arr_dtype = _dtype(arr)
  if arr_dtype == dtype:
    return arr
  # bool is implemented as lax:PRED, which is not compatible with lax.bitcast_convert_type.
  # We work around this by casting bool to uint8.
  if arr_dtype == bool_:
    arr = arr.astype(uint8)
  nbits_in = 8 * arr_dtype.itemsize
  nbits_out = 8 * _dtype(dtype).itemsize
  # Same width: a single bitcast suffices (going through uint8 for bool).
  if nbits_in == nbits_out:
    if dtype == bool_:
      return lax.bitcast_convert_type(arr, uint8).astype(dtype)
    return lax.bitcast_convert_type(arr, dtype)
  if nbits_out > nbits_in and (shape(arr)[-1] * nbits_in) % nbits_out != 0:
    raise ValueError("When changing to a larger dtype, its size must be a divisor "
                     "of the total size in bytes of the last axis of the array.")
  # Different widths: go through unsigned integers of the input/output widths.
  byte_dtypes = {8: uint8, 16: uint16, 32: uint32, 64: uint64}
  if nbits_in not in byte_dtypes:
    raise NotImplementedError(f"arr.view() for arr.dtype={arr_dtype}")
  if nbits_out not in byte_dtypes:
    raise NotImplementedError(f"arr.view(dtype) for dtype={dtype}")
  dt_in = byte_dtypes[nbits_in]
  dt_out = byte_dtypes[nbits_out]
  arr_bytes = lax.bitcast_convert_type(arr, dt_in)
  if nbits_in < nbits_out:
    # Widening: group consecutive elements of the last axis, shift the k-th
    # element of each group left by k * nbits_in and add the group up, so the
    # first element lands in the least-significant bits.
    shifts = arange(0, nbits_out, nbits_in, dtype=dt_out)
    arr_bytes = arr_bytes.reshape(arr.shape[:-1] + (-1, nbits_out // nbits_in)).astype(dt_out)
    arr_bytes = (arr_bytes << shifts).sum(-1).astype(dt_out)
  else:
    # Narrowing: split every element into pieces of nbits_out bits, starting
    # with the least-significant piece, then merge the new trailing axis into
    # the last axis.
    shifts = arange(0, nbits_in, nbits_out, dtype=dt_in)
    arr_bytes = ((arr_bytes[..., newaxis] >> shifts) & iinfo(dt_out).max).astype(dt_out)
    arr_bytes = arr_bytes.reshape(arr_bytes.shape[:-2] + (-1,))
  if dtype == bool_:
    return lax.bitcast_convert_type(arr_bytes, uint8).astype(dtype)
  return lax.bitcast_convert_type(arr_bytes, dtype)
5080
5081### track unimplemented functions
5082
# Text handed to `_wraps` as `lax_description` by `_not_implemented`.
_NOT_IMPLEMENTED_DESC = """
*** This function is not yet implemented by jax.numpy, and will raise NotImplementedError ***
"""

def _not_implemented(fun):
  """Return a stand-in for NumPy function `fun` that always raises.

  The stand-in accepts any arguments and raises NotImplementedError with a
  message naming `fun`. It is passed through `_wraps` with `update_doc=False`
  and `_NOT_IMPLEMENTED_DESC` as the description.
  """
  decorate = _wraps(fun, update_doc=False, lax_description=_NOT_IMPLEMENTED_DESC)

  def wrapped(*args, **kwargs):
    raise NotImplementedError(
        "Numpy function {} not yet implemented".format(fun))

  return decorate(wrapped)
5093
5094
5095### add method and operator overloads to arraylike classes
5096
5097# We add operator overloads to DeviceArray and ShapedArray. These method and
5098# operator overloads mainly just forward calls to the corresponding lax_numpy
5099# functions, which can themselves handle instances from any of these classes.
5100
# Python and NumPy scalar types accepted as the other operand of a binary op.
_scalar_types = (int, float, complex, np.generic)

def _defer_to_unrecognized_arg(binary_op):
  """Wrap `binary_op` so that unrecognized operand types get NotImplemented.

  The wrapper calls ``binary_op(self, other)`` when `other` is an instance of
  `_scalar_types`, `_arraylike_types` or `core.Tracer`; for anything else it
  returns NotImplemented without calling `binary_op`.
  """
  def deferring_binary_op(self, other):
    recognized = _scalar_types + _arraylike_types + (core.Tracer,)
    if isinstance(other, recognized):
      return binary_op(self, other)
    # Ensure that other array types have the chance to override arithmetic.
    return NotImplemented
  return deferring_binary_op
5110
def _swap_args(f):
  """Return a two-argument callable that invokes `f` with its arguments reversed."""
  def swapped(x, y):
    return f(y, x)
  return swapped
5113
def _unimplemented_setitem(self, i, x):
  """Reject item assignment by always raising TypeError.

  `i` and `x` are ignored; only the type of `self` appears in the message.
  """
  template = ("'{}' object does not support item assignment. JAX arrays are "
              "immutable; perhaps you want jax.ops.index_update or "
              "jax.ops.index_add instead?")
  owner_type = type(self)
  raise TypeError(template.format(owner_type))
5119
def _operator_round(number, ndigits=None):
  """Implementation of the ``round`` operator entry for arrays.

  Rounds `number` to `ndigits` decimals (0 when `ndigits` is None or falsy)
  using the module-level `round`. When `ndigits` is None the result is
  additionally cast with ``astype(int)``.
  """
  decimals = ndigits or 0
  rounded = round(number, decimals=decimals)
  if ndigits is None:
    # If `ndigits` is None, for a builtin float round(7.5) returns an integer.
    return rounded.astype(int)
  return rounded
5124
# Table of operator implementations, keyed by the operator name without its
# surrounding underscores. The forwarding loops in this file install each entry
# as `_<name>` on ShapedArray (wrapped in staticmethod) and as `__<name>__` on
# the DeviceArray classes.
_operators = {
    "getitem": _rewriting_take,
    "setitem": _unimplemented_setitem,
    "neg": negative,
    "pos": positive,
    # Binary operators go through _defer_to_unrecognized_arg, which returns
    # NotImplemented when the other operand is of an unrecognized type.
    "eq": _defer_to_unrecognized_arg(equal),
    "ne": _defer_to_unrecognized_arg(not_equal),
    "lt": _defer_to_unrecognized_arg(less),
    "le": _defer_to_unrecognized_arg(less_equal),
    "gt": _defer_to_unrecognized_arg(greater),
    "ge": _defer_to_unrecognized_arg(greater_equal),
    "abs": abs,
    # Reflected ("r"-prefixed) entries reuse the same function where the
    # forward entry is reused as-is, and wrap it in _swap_args otherwise so
    # the operands reach the function in left-to-right order.
    "add": _defer_to_unrecognized_arg(add),
    "radd": _defer_to_unrecognized_arg(add),
    "sub": _defer_to_unrecognized_arg(subtract),
    "rsub": _defer_to_unrecognized_arg(_swap_args(subtract)),
    "mul": _defer_to_unrecognized_arg(multiply),
    "rmul": _defer_to_unrecognized_arg(multiply),
    # NOTE(review): "div"/"rdiv" become __div__/__rdiv__, which Python 3 does
    # not use for `/` (that is "truediv"); presumably kept for legacy reasons.
    "div": _defer_to_unrecognized_arg(divide),
    "rdiv": _defer_to_unrecognized_arg(_swap_args(divide)),
    "truediv": _defer_to_unrecognized_arg(true_divide),
    "rtruediv": _defer_to_unrecognized_arg(_swap_args(true_divide)),
    "floordiv": _defer_to_unrecognized_arg(floor_divide),
    "rfloordiv": _defer_to_unrecognized_arg(_swap_args(floor_divide)),
    "divmod": _defer_to_unrecognized_arg(divmod),
    "rdivmod": _defer_to_unrecognized_arg(_swap_args(divmod)),
    "mod": _defer_to_unrecognized_arg(mod),
    "rmod": _defer_to_unrecognized_arg(_swap_args(mod)),
    "pow": _defer_to_unrecognized_arg(power),
    "rpow": _defer_to_unrecognized_arg(_swap_args(power)),
    "matmul": _defer_to_unrecognized_arg(matmul),
    "rmatmul": _defer_to_unrecognized_arg(_swap_args(matmul)),
    "and": _defer_to_unrecognized_arg(bitwise_and),
    "rand": _defer_to_unrecognized_arg(bitwise_and),
    "or": _defer_to_unrecognized_arg(bitwise_or),
    "ror": _defer_to_unrecognized_arg(bitwise_or),
    "xor": _defer_to_unrecognized_arg(bitwise_xor),
    "rxor": _defer_to_unrecognized_arg(bitwise_xor),
    "invert": bitwise_not,
    "lshift": _defer_to_unrecognized_arg(left_shift),
    "rshift": _defer_to_unrecognized_arg(right_shift),
    "rlshift": _defer_to_unrecognized_arg(_swap_args(left_shift)),
    "rrshift": _defer_to_unrecognized_arg(_swap_args(right_shift)),
    "round": _operator_round,
}
5170
5171# These numpy.ndarray methods are just refs to an equivalent numpy function
# Names of module-level functions that are also exposed as array methods. The
# forwarding loops look each name up with globals()[name], so every entry must
# be defined in this module. Both lists are concatenated and forwarded in the
# same way; the nondiff/diff split is presumably by differentiability (inferred
# from the names only).
_nondiff_methods = ["all", "any", "argmax", "argmin", "argpartition", "argsort",
                    "nonzero", "searchsorted", "round"]
_diff_methods = ["clip", "conj", "conjugate", "cumprod", "cumsum",
                 "diagonal", "dot", "max", "mean", "min", "prod", "ptp",
                 "ravel", "repeat", "sort", "squeeze", "std", "sum",
                 "swapaxes", "take", "tile", "trace", "transpose", "var"]
5178
5179# These methods are mentioned explicitly by nondiff_methods, so we create
5180# _not_implemented implementations of them here rather than in __init__.py.
5181# TODO(phawkins): implement these.
# `argpartition` has to exist as a module global because the method-forwarding
# loops resolve "argpartition" through globals(); the stub raises
# NotImplementedError when called.
argpartition = _not_implemented(np.argpartition)
# Names of the stubs defined in this module. NOTE(review): presumably read
# elsewhere (e.g. by the package __init__) -- no consumer is visible here.
_NOT_IMPLEMENTED = ['argpartition']
5184
5185# Set up operator, method, and property forwarding on Tracer instances containing
5186# ShapedArray avals by following the forwarding conventions for Tracer.
5187# Forward operators using a single-underscore-prefix naming convention:
for operator_name, function in _operators.items():
  setattr(ShapedArray, f"_{operator_name}", staticmethod(function))
# Methods and properties are forwarded through core.aval_method and
# core.aval_property respectively:
for method_name in _nondiff_methods + _diff_methods:
  setattr(ShapedArray, method_name, core.aval_method(globals()[method_name]))
ShapedArray.reshape = core.aval_method(_reshape)
ShapedArray.flatten = core.aval_method(ravel)
ShapedArray.T = core.aval_property(transpose)
ShapedArray.real = core.aval_property(real)
ShapedArray.imag = core.aval_property(imag)
ShapedArray.astype = core.aval_method(_astype)
ShapedArray.view = core.aval_method(_view)
ShapedArray.nbytes = core.aval_property(_nbytes)
5201
5202
5203# Forward operators, methods, and properties on DeviceArray to lax_numpy
5204# functions (with no Tracers involved; this forwarding is direct)
for device_array in (_DeviceArray, _CppDeviceArray):
  # Operators become real dunder methods on the concrete array classes.
  for operator_name, function in _operators.items():
    setattr(device_array, f"__{operator_name}__", function)
  # Named methods are the module-level functions themselves.
  for method_name in _nondiff_methods + _diff_methods:
    setattr(device_array, method_name, globals()[method_name])
  device_array.reshape = _reshape
  device_array.flatten = ravel
  device_array.T = property(transpose)
  device_array.real = property(real)
  device_array.imag = property(imag)
  device_array.astype = _astype
  device_array.view = _view
  device_array.nbytes = property(_nbytes)
5218
5219
5220# Experimental support for NumPy's module dispatch with NEP-37.
5221# Currently requires https://github.com/seberg/numpy-dispatch
_JAX_ARRAY_TYPES = (DeviceArray, core.Tracer)
_HANDLED_ARRAY_TYPES = _JAX_ARRAY_TYPES + (np.ndarray,)

def __array_module__(self, types):
  """Return `jax.numpy` if every type in `types` is handled, else NotImplemented.

  A type counts as handled when it is a subclass of one of
  `_HANDLED_ARRAY_TYPES`. `self` is not used.
  """
  for candidate in types:
    if not issubclass(candidate, _HANDLED_ARRAY_TYPES):
      return NotImplemented
  return jax.numpy

ShapedArray._array_module = staticmethod(__array_module__)
_DeviceArray.__array_module__ = __array_module__
_CppDeviceArray.__array_module__ = __array_module__
5234
5235
5236# Extra methods that are handy
ShapedArray.broadcast = core.aval_method(lax.broadcast)
ShapedArray.broadcast_in_dim = core.aval_method(lax.broadcast_in_dim)
ShapedArray.split = core.aval_method(split)
for device_array in (_DeviceArray, _CppDeviceArray):
  # The concrete array classes get the functions directly, unwrapped.
  device_array.broadcast = lax.broadcast
  device_array.broadcast_in_dim = lax.broadcast_in_dim
  device_array.split = split
5244
def _compress_method(a, condition, axis=None, out=None):
  """Method form of `compress`, i.e. ``a.compress(condition, axis, out)``.

  The module-level `compress` takes `condition` first, whereas a method
  receives the array first, so the two leading arguments are swapped here.
  """
  return compress(condition, a, axis, out)

# Every other method forwarded to ShapedArray in this file is wrapped in
# core.aval_method; `compress` is wrapped the same way for consistency.
# NOTE(review): a plain function here would presumably end up bound to the
# abstract value instead of the Tracer -- confirm against
# core.Tracer.__getattr__.
setattr(ShapedArray, "compress", core.aval_method(_compress_method))
setattr(_DeviceArray, "compress", _compress_method)
setattr(_CppDeviceArray, "compress", _compress_method)
5251
@partial(jit, static_argnums=(1,2,3))
def _multi_slice(arr,
                 start_indices: Tuple[Tuple[int, ...]],
                 limit_indices: Tuple[Tuple[int, ...]],
                 removed_dims: Tuple[Tuple[int, ...]]):
  """Extracts multiple slices from `arr`.

  This is used to shard DeviceArray arguments to pmap. It's implemented as a
  DeviceArray method here to avoid circular imports.

  Args:
    arr: the array to slice.
    start_indices: one tuple of start indices per slice, passed to `lax.slice`.
    limit_indices: one tuple of limit indices per slice, passed to `lax.slice`.
    removed_dims: one tuple of axes per slice; those axes are dropped from the
      shape of that slice. An empty tuple leaves the slice's shape unchanged.

  Returns:
    A list with one sliced (and possibly reshaped) array per entry of the
    index arguments, in order.
  """
  results = []
  for starts, limits, removed in safe_zip(start_indices, limit_indices, removed_dims):
    sliced = lax.slice(arr, starts, limits)
    if removed:
      # Drop only this slice's own axes. Indexing with the whole `removed_dims`
      # collection would mix in the axes of every other slice.
      sliced = sliced.reshape(tuple(np.delete(sliced.shape, removed)))
    results.append(sliced)
  return results
# Expose the slicing helper as a method on both DeviceArray classes.
_DeviceArray._multi_slice = _multi_slice
_CppDeviceArray._multi_slice = _multi_slice
5271
5272
5273# Syntactic sugar for scatter operations.
class _IndexUpdateHelper:
  # Note: this docstring will appear as the docstring for the `at` property.
  """Indexable helper object to call indexed update functions.

  The `at` property is syntactic sugar for calling the indexed update functions
  defined in :mod:`jax.ops`, and acts as a pure equivalent of in-place
  modifications.

  In particular:

  - ``x = x.at[idx].set(y)`` is a pure equivalent of ``x[idx] = y``.
  - ``x = x.at[idx].add(y)`` is a pure equivalent of ``x[idx] += y``.
  - ``x = x.at[idx].mul(y)`` is a pure equivalent of ``x[idx] *= y``.
  - ``x = x.at[idx].min(y)`` is a pure equivalent of
    ``x[idx] = minimum(x[idx], y)``.
  - ``x = x.at[idx].max(y)`` is a pure equivalent of
    ``x[idx] = maximum(x[idx], y)``.
  """
  __slots__ = ("array",)

  def __init__(self, array):
    # The array that later update calls operate on.
    self.array = array

  def __getitem__(self, index):
    # Indexing only records the index; the update itself happens in the
    # methods of the returned reference object.
    source = self.array
    return _IndexUpdateRef(source, index)

  def __repr__(self):
    return f"_IndexUpdateHelper({self.array!r})"
5301
5302
class _IndexUpdateRef:
  """Helper object to call indexed update functions for an (advanced) index.

  This object references a source array and a specific indexer into that array.
  Methods on this object return copies of the source array that have been
  modified at the positions specified by the indexer.
  """
  __slots__ = ("array", "index")

  def __init__(self, array, index):
    self.array = array
    self.index = index

  def __repr__(self):
    return f"_IndexUpdateRef({self.array!r}, {self.index!r})"

  def set(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] = y``.

    ``x.at[idx].set(y)`` is syntactic sugar for
    ``jax.ops.index_update(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] = y``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_update(self.array, self.index, values,
                            indices_are_sorted=indices_are_sorted,
                            unique_indices=unique_indices)

  def add(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] += y``.

    ``x.at[idx].add(y)`` is syntactic sugar for
    ``jax.ops.index_add(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] += y``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_add(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)

  def mul(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] *= y``.

    ``x.at[idx].mul(y)`` is syntactic sugar for
    ``jax.ops.index_mul(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>` ``x[idx] *= y``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_mul(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)

  def min(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] = minimum(x[idx], y)``.

    ``x.at[idx].min(y)`` is syntactic sugar for
    ``jax.ops.index_min(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>`
    ``x[idx] = minimum(x[idx], y)``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_min(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)

  def max(self, values, indices_are_sorted=False, unique_indices=False):
    """Pure equivalent of ``x[idx] = maximum(x[idx], y)``.

    ``x.at[idx].max(y)`` is syntactic sugar for
    ``jax.ops.index_max(x, jax.ops.index[idx], y)``, and
    returns the value of ``x`` that would result from the NumPy-style
    :mod:`indexed assignment <numpy.doc.indexing>`
    ``x[idx] = maximum(x[idx], y)``.

    See :mod:`jax.ops` for details.
    """
    return ops.index_max(self.array, self.index, values,
                         indices_are_sorted=indices_are_sorted,
                         unique_indices=unique_indices)
5390
# `x.at` builds an _IndexUpdateHelper around the array it is accessed on.
_DeviceArray.at = property(_IndexUpdateHelper)
_CppDeviceArray.at = property(_IndexUpdateHelper)
ShapedArray.at = core.aval_property(_IndexUpdateHelper)
5394