1# Copyright 2020 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14from typing import Any, Callable, Mapping, Optional, Tuple, Union 15from .bfgs import minimize_bfgs 16from typing import NamedTuple 17import jax.numpy as jnp 18 19 20class OptimizeResults(NamedTuple): 21 """Object holding optimization results. 22 23 Parameters: 24 x: final solution. 25 success: True if optimization succeeded. 26 status: integer solver specific return code. 0 means nominal. 27 message: solver specific message. 28 fun: final function value. 29 jac: final jacobian array. 30 hess_inv: final inverse Hessian estimate. 31 nfev: integer number of funcation calls used. 32 njev: integer number of gradient evaluations. 33 nit: integer number of iterations of the optimization algorithm. 34 """ 35 x: jnp.ndarray 36 success: Union[bool, jnp.ndarray] 37 status: Union[int, jnp.ndarray] 38 message: str 39 fun: jnp.ndarray 40 jac: jnp.ndarray 41 hess_inv: jnp.ndarray 42 nfev: Union[int, jnp.ndarray] 43 njev: Union[int, jnp.ndarray] 44 nit: Union[int, jnp.ndarray] 45 46 47def minimize( 48 fun: Callable, 49 x0: jnp.ndarray, 50 args: Tuple = (), 51 *, 52 method: str, 53 tol: Optional[float] = None, 54 options: Optional[Mapping[str, Any]] = None, 55) -> OptimizeResults: 56 """Minimization of scalar function of one or more variables. 57 58 This API for this function matches SciPy with some minor deviations: 59 - Gradients of ``fun`` are calculated automatically using JAX's autodiff 60 support when required. 61 - The ``method`` argument is required. You must specify a solver. 62 - Various optional arguments in the SciPy interface have not yet been 63 implemented. 64 - Optimization results may differ from SciPy due to differences in the line 65 search implementation. 66 67 ``minimize`` supports ``jit`` compilation. It does not yet support 68 differentiation or arguments in the form of multi-dimensional arrays, but 69 support for both is planned. 70 71 Args: 72 fun: the objective function to be minimized, ``fun(x, *args) -> float``, 73 where ``x`` is an 1-D array with shape ``(n,)`` and ``args`` is a tuple 74 of the fixed parameters needed to completely specify the function. 75 ``fun`` must support differentiation. 76 x0: initial guess. Array of real elements of size ``(n,)``, where 'n' is 77 the number of independent variables. 78 args: extra arguments passed to the objective function. 79 method: solver type. Currently only "BFGS" is supported. 80 tol: tolerance for termination. For detailed control, use solver-specific 81 options. 82 options: a dictionary of solver options. All methods accept the following 83 generic options: 84 maxiter : int 85 Maximum number of iterations to perform. Depending on the 86 method each iteration may use several function evaluations. 87 88 Returns: OptimizeResults object. 89 """ 90 if options is None: 91 options = {} 92 93 fun_with_args = lambda x: fun(x, *args) 94 95 if method.lower() == 'bfgs': 96 results = minimize_bfgs(fun_with_args, x0, **options) 97 message = ("status meaning: 0=converged, 1=max BFGS iters reached, " 98 "3=zoom failed, 4=saddle point reached, " 99 "5=max line search iters reached, -1=undefined") 100 success = (results.converged) & jnp.logical_not(results.failed) 101 return OptimizeResults(x=results.x_k, 102 success=success, 103 status=results.status, 104 message=message, 105 fun=results.f_k, 106 jac=results.g_k, 107 hess_inv=results.H_k, 108 nfev=results.nfev, 109 njev=results.ngev, 110 nit=results.k) 111 112 raise ValueError("Method {} not recognized".format(method)) 113