1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from typing import Any, Callable, Mapping, Optional, Tuple, Union
15from .bfgs import minimize_bfgs
16from typing import NamedTuple
17import jax.numpy as jnp
18
19
class OptimizeResults(NamedTuple):
  """Object holding optimization results.

  Fields annotated as ``Union[..., jnp.ndarray]`` may hold JAX arrays instead
  of Python scalars, e.g. when the results are computed under ``jit``.

  Attributes:
    x: final solution.
    success: True if optimization succeeded.
    status: integer solver specific return code. 0 means nominal.
    message: solver specific message.
    fun: final function value.
    jac: final Jacobian array; for a scalar objective this is its gradient.
    hess_inv: final inverse Hessian estimate.
    nfev: integer number of function calls used.
    njev: integer number of gradient evaluations.
    nit: integer number of iterations of the optimization algorithm.
  """
  x: jnp.ndarray
  success: Union[bool, jnp.ndarray]
  status: Union[int, jnp.ndarray]
  message: str
  fun: jnp.ndarray
  jac: jnp.ndarray
  hess_inv: jnp.ndarray
  nfev: Union[int, jnp.ndarray]
  njev: Union[int, jnp.ndarray]
  nit: Union[int, jnp.ndarray]
45
46
def minimize(
    fun: Callable,
    x0: jnp.ndarray,
    args: Tuple = (),
    *,
    method: str,
    tol: Optional[float] = None,
    options: Optional[Mapping[str, Any]] = None,
) -> OptimizeResults:
  """Minimization of scalar function of one or more variables.

  The API for this function matches SciPy with some minor deviations:
  - Gradients of ``fun`` are calculated automatically using JAX's autodiff
    support when required.
  - The ``method`` argument is required. You must specify a solver.
  - Various optional arguments in the SciPy interface have not yet been
    implemented.
  - Optimization results may differ from SciPy due to differences in the line
    search implementation.

  ``minimize`` supports ``jit`` compilation. It does not yet support
  differentiation or arguments in the form of multi-dimensional arrays, but
  support for both is planned.

  Args:
    fun: the objective function to be minimized, ``fun(x, *args) -> float``,
      where ``x`` is a 1-D array with shape ``(n,)`` and ``args`` is a tuple
      of the fixed parameters needed to completely specify the function.
      ``fun`` must support differentiation.
    x0: initial guess. Array of real elements with shape ``(n,)``, where ``n``
      is the number of independent variables.
    args: extra arguments passed to the objective function. Any iterable is
      accepted; it is converted to a tuple once, before optimization starts.
    method: solver type. Currently only "BFGS" is supported (matched
      case-insensitively).
    tol: tolerance for termination. As in SciPy, for BFGS it is used as the
      ``gtol`` solver option unless ``options`` already sets ``gtol``. For
      detailed control, use solver-specific options.
    options: a dictionary of solver options, forwarded as keyword arguments
      to the solver. The caller's mapping is never modified. All methods
      accept the following generic options:
        maxiter : int
            Maximum number of iterations to perform. Depending on the
            method each iteration may use several function evaluations.

  Returns:
    An ``OptimizeResults`` object.

  Raises:
    TypeError: if ``args`` is not iterable.
    ValueError: if ``method`` is not recognized.
  """
  # Work on a private copy so defaulting solver options (see ``tol`` below)
  # never mutates the mapping the caller passed in.
  options = dict(options) if options is not None else {}

  # Materialize ``args`` exactly once: the objective is evaluated many times,
  # and a one-shot iterator (e.g. a generator) would otherwise be exhausted
  # after the first evaluation. A non-iterable also fails here with a clear
  # message rather than deep inside tracing of ``fun``.
  try:
    args = tuple(args)
  except TypeError as err:
    raise TypeError(
        "args argument to minimize must be a tuple, got {}".format(
            type(args).__name__)) from err

  fun_with_args = lambda x: fun(x, *args)

  if method.lower() == 'bfgs':
    if tol is not None:
      # Mirror scipy.optimize.minimize, which uses ``tol`` as the default
      # gradient tolerance for BFGS; an explicit options['gtol'] wins.
      options.setdefault('gtol', tol)
    results = minimize_bfgs(fun_with_args, x0, **options)
    message = ("status meaning: 0=converged, 1=max BFGS iters reached, "
               "3=zoom failed, 4=saddle point reached, "
               "5=max line search iters reached, -1=undefined")
    # Success requires both convergence and the absence of a solver failure.
    success = results.converged & jnp.logical_not(results.failed)
    return OptimizeResults(x=results.x_k,
                           success=success,
                           status=results.status,
                           message=message,
                           fun=results.f_k,
                           jac=results.g_k,
                           hess_inv=results.H_k,
                           nfev=results.nfev,
                           njev=results.ngev,
                           nit=results.k)

  raise ValueError("Method {} not recognized".format(method))
113