# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
from jax import lax
from . import lax_numpy as jnp

from jax import jit
from .util import _wraps
from .linalg import eigvals as _eigvals
from jax import ops as jaxops


def _to_inexact_type(dtype):
  """Return `dtype` unchanged if it is inexact (float/complex), else `jnp.float_`."""
  return dtype if jnp.issubdtype(dtype, jnp.inexact) else jnp.float_


def _promote_inexact(arr):
  """Cast `arr` to an inexact dtype; float/complex arrays keep their dtype."""
  return lax.convert_element_type(arr, _to_inexact_type(arr.dtype))


@jit
def _roots_no_zeros(p):
  """Roots of the polynomial with coefficients `p`, highest degree first.

  The roots are the eigenvalues of the polynomial's companion matrix.

  Assumes `p` has no leading zeros (a zero `p[0]` makes the top row of the
  companion matrix divide by zero) and has length > 1.
  """
  p = _promote_inexact(p)

  # Companion matrix: ones on the first subdiagonal, -p[1:] / p[0] in row 0.
  A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1)
  A = jaxops.index_update(A, jaxops.index[0, :], -p[1:] / p[0])
  return _eigvals(A)


@jit
def _nonzero_range(arr):
  """Return `(start, end)` such that `arr[:start]` and `arr[end:]` are all zero.

  `start` is the index of the first nonzero entry and `end` is one past the
  index of the last nonzero entry. For an all-zero `arr` this yields
  `(0, arr.size)`, so callers must handle that case separately.
  """
  is_zero = arr == 0
  # argmin over a boolean array finds the first False, i.e. the first nonzero.
  start = jnp.argmin(is_zero)
  end = is_zero.size - jnp.argmin(is_zero[::-1])
  return start, end


@_wraps(np.roots, lax_description="""\
If the input polynomial coefficients of length n do not start with zero,
the polynomial is of degree n - 1 leading to n - 1 roots.
If the coefficients do have leading zeros, the polynomial they define
has a smaller degree and the number of roots (and thus the output shape)
is value dependent.

The general implementation can therefore not be transformed with jit.
If the coefficients are guaranteed to have no leading zeros, use the
keyword argument `strip_zeros=False` to get a jit-compatible variant::

    >>> roots_unsafe = jax.jit(functools.partial(jnp.roots, strip_zeros=False))
    >>> roots_unsafe([1, 2])  # ok
    DeviceArray([-2.+0.j], dtype=complex64)
    >>> roots_unsafe([0, 1, 2])  # problem
    DeviceArray([nan+nanj, nan+nanj], dtype=complex64)
    >>> jnp.roots([0, 1, 2])  # use the no-jit version instead
    DeviceArray([-2.+0.j], dtype=complex64)
""")
def roots(p, *, strip_zeros=True):
  # ported from https://github.com/numpy/numpy/blob/v1.17.0/numpy/lib/polynomial.py#L168-L251
  p = jnp.atleast_1d(p)
  if p.ndim != 1:
    raise ValueError("Input must be a rank-1 array.")
  # Like numpy, compute in an inexact dtype so every return path (including
  # the "only roots at zero" path below) yields float/complex, never ints.
  p = _promote_inexact(p)

  # strip_zeros=False is unsafe because leading zeros aren't removed
  if not strip_zeros:
    if p.size > 1:
      return _roots_no_zeros(p)
    else:
      return jnp.array([])

  if jnp.all(p == 0):
    return jnp.array([])

  # factor out trivial roots; start/end are concrete on this (non-jit) path,
  # so convert them to Python ints for use as slice bounds and a shape.
  start, end = _nonzero_range(p)
  start, end = int(start), int(end)
  # number of trailing zeros = number of roots at 0
  trailing_zeros = p.size - end

  # strip leading and trailing zeros
  p = p[start:end]

  if p.size < 2:
    # Constant polynomial after stripping: the only roots are the zeros.
    return jnp.zeros(trailing_zeros, p.dtype)

  nonzero_roots = _roots_no_zeros(p)
  # combine roots and zero roots, padding in the roots' own dtype
  return jnp.hstack(
      (nonzero_roots, jnp.zeros(trailing_zeros, nonzero_roots.dtype)))