1"""
2Spectral Algorithm for Nonlinear Equations
3"""
4import collections
5
6import numpy as np
7from scipy.optimize import OptimizeResult
8from scipy.optimize.optimize import _check_unknown_options
9from .linesearch import _nonmonotone_line_search_cruz, _nonmonotone_line_search_cheng
10
class _NoConvergence(Exception):
    """Raised by the wrapped function once the evaluation budget is used up."""
13
14
def _root_df_sane(func, x0, args=(), ftol=1e-8, fatol=1e-300, maxfev=1000,
                  fnorm=None, callback=None, disp=False, M=10, eta_strategy=None,
                  sigma_eps=1e-10, sigma_0=1.0, line_search='cruz', **unknown_options):
    r"""
    Find a root of a nonlinear system with the derivative-free DF-SANE method

    Options
    -------
    ftol : float, optional
        Relative tolerance on the residual norm.
    fatol : float, optional
        Absolute tolerance on the residual norm.
        Iteration stops once ``||func(x)|| < fatol + ftol ||func(x_0)||``.
    fnorm : callable, optional
        Norm used in the convergence test. The 2-norm is used if None.
    maxfev : int, optional
        Upper limit on the number of function evaluations.
    disp : bool, optional
        If true, print the residual norm and spectral coefficient of
        each iteration to stdout.
    eta_strategy : callable, optional
        Rule for the ``eta_k`` parameter, which allows some growth of
        ``||F||**2``.  Invoked as ``eta_k = eta_strategy(k, x, F)`` where
        `k` is the iteration number, `x` the current iterate and `F` the
        current residual. It should give ``eta_k > 0`` with
        ``sum(eta, k=0..inf) < inf``.
        Default: ``||F_0||**2 / (1 + k)**2``, with ``F_0`` the residual at
        the initial point.
    sigma_eps : float, optional
        Bound for the spectral coefficient, which is kept within
        ``sigma_eps < sigma < 1/sigma_eps``.
        Default: 1e-10
    sigma_0 : float, optional
        Spectral coefficient used for the first step.
        Default: 1.0
    M : int, optional
        How many previous iterates the nonmonotone line search remembers.
        Default: 10
    line_search : {'cruz', 'cheng'}
        Which line search to use. 'cruz' is the original one from
        [Martinez & Raydan. Math. Comp. 75, 1429 (2006)], 'cheng' is the
        modified search from [Cheng & Li. IMA J. Numer. Anal. 29, 814 (2009)].
        Default: 'cruz'

    References
    ----------
    .. [1] "Spectral residual method without gradient information for solving
           large-scale nonlinear systems of equations." W. La Cruz,
           J.M. Martinez, M. Raydan. Math. Comp. **75**, 1429 (2006).
    .. [2] W. La Cruz, Opt. Meth. Software, 29, 24 (2014).
    .. [3] W. Cheng, D.-H. Li. IMA J. Numer. Anal. **29**, 814 (2009).

    """
    _check_unknown_options(unknown_options)

    if line_search not in ('cheng', 'cruz'):
        raise ValueError("Invalid value %r for 'line_search'" % (line_search,))
    use_cruz = line_search == 'cruz'

    # The merit function is the residual 2-norm raised to this power.
    merit_power = 2

    if eta_strategy is None:
        # Not the choice made in [1]: their eta changes when F is rescaled.
        def eta_strategy(k, x, F):
            # `merit_0` (squared 2-norm of the initial residual) is read
            # from the enclosing scope at call time.
            return merit_0 / (1 + k)**2

    if fnorm is None:
        def fnorm(F):
            # `merit_cur` (squared 2-norm of the current residual) is read
            # from the enclosing scope, so the norm is not recomputed.
            return merit_cur**(1.0/merit_power)

    def fmerit(F):
        return np.linalg.norm(F)**merit_power

    eval_counter = [0]
    wrapped, x_cur, x_shape, merit_cur, F_cur, is_complex = _wrap_func(
        func, x0, fmerit, eval_counter, maxfev, args)

    merit_0 = merit_cur
    sigma = sigma_0
    norm_0 = fnorm(F_cur)

    # State of the 'cruz' line search: the last M merit values
    recent_merits = collections.deque([merit_cur], M)

    # State of the 'cheng' line search
    cheng_Q = 1.0
    cheng_C = merit_0

    nit = 0
    converged = False

    while True:
        norm_cur = fnorm(F_cur)

        if disp:
            print("iter %d: ||F|| = %g, sigma = %g" % (nit, norm_cur, sigma))

        if callback is not None:
            callback(x_cur, F_cur)

        if norm_cur < ftol * norm_0 + fatol:
            converged = True
            break

        # Keep the spectral parameter within its bounds, from [2]
        sigma_limit = 1/sigma_eps
        sigma_mag = abs(sigma)
        if sigma_mag > sigma_limit:
            sigma = sigma_limit * np.sign(sigma)
        elif sigma_mag < sigma_eps:
            sigma = sigma_eps

        # Search along the scaled negative residual
        step_dir = -sigma * F_cur

        # Nonmonotone line search; running out of evaluations ends the loop
        eta = eta_strategy(nit, x_cur, F_cur)
        try:
            if use_cruz:
                _alpha, x_new, merit_new, F_new = _nonmonotone_line_search_cruz(
                    wrapped, x_cur, step_dir, recent_merits, eta=eta)
            else:
                (_alpha, x_new, merit_new, F_new,
                 cheng_C, cheng_Q) = _nonmonotone_line_search_cheng(
                    wrapped, x_cur, step_dir, merit_cur, cheng_C, cheng_Q,
                    eta=eta)
        except _NoConvergence:
            break

        # New spectral parameter from the step and the change in residual
        dx = x_new - x_cur
        dF = F_new - F_cur
        sigma = np.vdot(dx, dx) / np.vdot(dx, dF)

        # Accept the step
        x_cur, F_cur, merit_cur = x_new, F_new, merit_new

        # Remember the merit value for the 'cruz' search
        if use_cruz:
            recent_merits.append(merit_new)

        nit += 1

    if converged:
        message = "successful convergence"
    else:
        message = "too many function evaluations required"

    x = _wrap_result(x_cur, is_complex, shape=x_shape)
    F = _wrap_result(F_cur, is_complex)

    return OptimizeResult(x=x, success=converged,
                          message=message,
                          fun=F, nfev=eval_counter[0], nit=nit)
163
164
def _wrap_func(func, x0, fmerit, nfev_list, maxfev, args=()):
    """
    Wrap a function and an initial value so that (i) complex values
    are wrapped to reals, (ii) the value of a merit function
    ``fmerit(F)`` is computed at the same time, and (iii) the number of
    function evaluations is maintained and an exception is raised once
    the limit is reached.

    Parameters
    ----------
    func : callable
        Function to wrap
    x0 : array_like
        Initial value. Integer or boolean input is converted to float,
        so that `func` sees a floating-point array on every call.
    fmerit : callable
        Merit function fmerit(F) for computing merit value from residual.
    nfev_list : list
        List to store number of evaluations in. Should be [0] in the beginning.
        Its first element is set to 1 after the initial evaluation, which is
        always performed, and incremented on each wrapped call.
    maxfev : int
        Maximum number of evaluations before _NoConvergence is raised.
    args : tuple
        Extra arguments to func

    Returns
    -------
    wrap_func : callable
        Wrapped function, to be called as
        ``f, F = wrap_func(x)``, with `x` a 1-D real array laid out like
        `x0_wrap`. Returns the merit value `f` and the raveled, real
        residual `F`. Raises `_NoConvergence` without calling `func` when
        ``nfev_list[0] >= maxfev``.
    x0_wrap : ndarray of float
        Wrapped initial value; raveled to 1-D and complex
        values mapped to reals.
    x0_shape : tuple
        Shape of the initial value array
    f : float
        Merit function at F
    F : ndarray of float
        Residual at x0_wrap
    is_complex : bool
        Whether complex values were mapped to reals

    """
    x0 = np.asarray(x0)
    if not np.issubdtype(x0.dtype, np.inexact):
        # Later iterates are always floating point; make the first
        # evaluation and the returned initial value consistent with that.
        x0 = x0.astype(float)
    x0_shape = x0.shape
    F = np.asarray(func(x0, *args)).ravel()
    # Either a complex starting point or a complex residual switches the
    # whole problem to the real-pair representation.
    is_complex = np.iscomplexobj(x0) or np.iscomplexobj(F)
    x0 = x0.ravel()

    nfev_list[0] = 1

    def wrap_func(x):
        if nfev_list[0] >= maxfev:
            raise _NoConvergence()
        nfev_list[0] += 1
        if is_complex:
            x = _real2complex(x)
        F = np.asarray(func(x.reshape(x0_shape), *args)).ravel()
        if is_complex:
            F = _complex2real(F)
        return fmerit(F), F

    if is_complex:
        x0 = _complex2real(x0)
        F = _complex2real(F)

    return wrap_func, x0, x0_shape, fmerit(F), F, is_complex
237
238
def _wrap_result(result, is_complex, shape=None):
    """
    Undo the wrapping of a result array: map real pairs back to complex
    values when `is_complex` is true, then reshape to `shape` if one is given.
    """
    unwrapped = _real2complex(result) if is_complex else result
    return unwrapped if shape is None else unwrapped.reshape(shape)
250
251
def _real2complex(x):
    """Reinterpret consecutive pairs of floats in `x` as complex128 values."""
    as_floats = np.ascontiguousarray(x, dtype=float)
    return as_floats.view(np.complex128)
254
255
def _complex2real(z):
    """Reinterpret each complex value in `z` as a (real, imag) pair of float64."""
    as_complex = np.ascontiguousarray(z, dtype=complex)
    return as_complex.view(np.float64)
258