1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16import numpy as np
17from jax import lax
18from . import lax_numpy as jnp
19
20from jax import jit
21from .util import _wraps
22from .linalg import eigvals as _eigvals
23from jax import ops as jaxops
24
25
26def _to_inexact_type(type):
27  return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
28
29
def _promote_inexact(arr):
  """Convert ``arr`` to an inexact dtype.

  Arrays that are already floating or complex keep their dtype; all others
  are converted to the dtype chosen by ``_to_inexact_type``.
  """
  target_dtype = _to_inexact_type(arr.dtype)
  return lax.convert_element_type(arr, target_dtype)
32
33
@jit
def _roots_no_zeros(p):
  """Return the roots of polynomial ``p`` as eigenvalues of its companion matrix.

  Preconditions (not checked here): ``p`` has no leading zeros and has at
  least two coefficients. A leading zero makes the division by ``p[0]``
  below non-finite.
  """
  coeffs = _promote_inexact(p)
  degree = coeffs.size - 1

  # Companion matrix: ones on the first sub-diagonal, and a first row holding
  # the remaining coefficients normalized (and negated) by the leading one.
  companion = jnp.diag(jnp.ones((degree - 1,), coeffs.dtype), -1)
  first_row = -coeffs[1:] / coeffs[0]
  companion = jaxops.index_update(companion, jaxops.index[0, :], first_row)

  return _eigvals(companion)
44
45
46@jit
47def _nonzero_range(arr):
48  # return start and end s.t. arr[:start] = 0 = arr[end:] padding zeros
49  is_zero = arr == 0
50  start = jnp.argmin(is_zero)
51  end = is_zero.size - jnp.argmin(is_zero[::-1])
52  return start, end
53
54
@_wraps(np.roots, lax_description="""\
If the input polynomial coefficients of length n do not start with zero,
the polynomial is of degree n - 1 leading to n - 1 roots.
If the coefficients do have leading zeros, the polynomial they define
has a smaller degree and the number of roots (and thus the output shape)
is value dependent.

The general implementation can therefore not be transformed with jit.
If the coefficients are guaranteed to have no leading zeros, use the
keyword argument `strip_zeros=False` to get a jit-compatible variant::

    >>> roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
    >>> roots_unsafe([1, 2])     # ok
    DeviceArray([-2.+0.j], dtype=complex64)
    >>> roots_unsafe([0, 1, 2])  # problem
    DeviceArray([nan+nanj, nan+nanj], dtype=complex64)
    >>> jnp.roots([0, 1, 2])         # use the no-jit version instead
    DeviceArray([-2.+0.j], dtype=complex64)
""")
def roots(p, *, strip_zeros=True):
  # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
  #
  # Computes the roots of the polynomial whose coefficients are ``p``.
  # Raises ValueError if ``p`` (after atleast_1d) is not rank 1.
  # No docstring here on purpose: the public doc is generated by ``_wraps``.
  p = jnp.atleast_1d(p)
  if p.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")

  # strip_zeros=False is unsafe because leading zeros aren't removed, but it
  # makes no value-dependent decisions and therefore works under jit.
  if not strip_zeros:
    if p.size > 1:
      return _roots_no_zeros(p)
    else:
      return jnp.array([])

  # All-zero input: there are no roots to report.
  if jnp.all(p == 0):
    return jnp.array([])

  # Factor out trivial roots. ``_nonzero_range`` is jitted and returns device
  # scalars; convert them to Python ints so the slice bounds and the length
  # passed to ``jnp.zeros`` below are plain static values.
  start, end = (int(i) for i in _nonzero_range(p))
  # number of trailing zeros = number of roots at 0
  trailing_zeros = p.size - end

  # strip leading and trailing zeros
  p = p[start:end]

  # Build the zero roots with an inexact dtype so that integer coefficients
  # such as [1, 0, 0] yield floating-point zeros rather than integer ones,
  # consistent with the inexact (complex) output of the eigenvalue path.
  zero_roots = jnp.zeros(trailing_zeros, _to_inexact_type(p.dtype))

  if p.size < 2:
    # A nonzero constant remains after stripping: only the roots at 0 exist.
    return zero_roots
  # combine the nontrivial roots with the roots at 0
  return jnp.hstack((_roots_no_zeros(p), zero_roots))
105