1# -*- coding: utf-8 -*-
2# This file is part of QuTiP: Quantum Toolbox in Python.
3#
4#    Copyright (c) 2014 and later, Alexander J G Pitchford
5#    All rights reserved.
6#
7#    Redistribution and use in source and binary forms, with or without
8#    modification, are permitted provided that the following conditions are
9#    met:
10#
11#    1. Redistributions of source code must retain the above copyright notice,
12#       this list of conditions and the following disclaimer.
13#
14#    2. Redistributions in binary form must reproduce the above copyright
15#       notice, this list of conditions and the following disclaimer in the
16#       documentation and/or other materials provided with the distribution.
17#
18#    3. Neither the name of the QuTiP: Quantum Toolbox in Python nor the names
19#       of its contributors may be used to endorse or promote products derived
20#       from this software without specific prior written permission.
21#
22#    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23#    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24#    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
25#    PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26#    HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27#    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28#    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29#    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30#    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31#    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32#    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33###############################################################################
34
35# @author: Alexander Pitchford
36# @email1: agp1@aber.ac.uk
37# @email2: alex.pitchford@gmail.com
38# @organization: Aberystwyth University
39# @supervisor: Daniel Burgarth
40
41"""
42Classes here are expected to implement a run_optimization function
43that will use some method for optimising the control pulse, as defined
by the control amplitudes. The system that the pulse acts upon is defined
by the Dynamics object that must be passed in the instantiation.
46
47The methods are typically N dimensional function optimisers that
48find the minima of a fidelity error function. Note the number of variables
49for the fidelity function is the number of control timeslots,
50i.e. n_ctrls x Ntimeslots
51The methods will call functions on the Dynamics.fid_computer object,
one or many times per iteration,
to get the fidelity error and gradient with respect to the amplitudes.
The optimisation will stop when one of the termination conditions is met,
for example: the fidelity aim has been reached, a local minimum has been
found, or the maximum time allowed has been exceeded.
57
58These function optimisation methods are so far from SciPy.optimize
59The two methods implemented are:
60
61    BFGS - Broyden–Fletcher–Goldfarb–Shanno algorithm
62
        This is a quasi second order Newton method. It uses successive calls
        to the gradient function to make an estimation of the curvature
        (Hessian) and hence direct its search for the function minima.
        The SciPy implementation is pure Python and hence its execution speed
        is not high.
68        use subclass: OptimizerBFGS
69
70    L-BFGS-B - Bounded, limited memory BFGS
71
        This is a version of the BFGS method where the Hessian approximation
        is only based on a set of the most recent gradient calls. It generally
        performs better where there are a large number of variables.
        The SciPy implementation of L-BFGS-B is a wrapper around a well
        established and actively maintained implementation in Fortran.
        It is therefore very fast.
78        # See SciPy documentation for credit and details on the
79        # scipy.optimize.fmin_l_bfgs_b function
80        use subclass: OptimizerLBFGSB
81
82The baseclass Optimizer implements the function wrappers to the
83fidelity error, gradient, and iteration callback functions.
These are called from within the SciPy optimisation functions.
85The subclasses implement the algorithm specific pulse optimisation function.
86"""
87
88import functools
89import numpy as np
90import timeit
91import warnings
92from packaging.version import parse as _parse_version
93import scipy
94import scipy.optimize as spopt
95import copy
96import collections
97# QuTiP
98from qutip.qobj import Qobj
99import qutip.logging_utils as logging
100logger = logging.get_logger()
101# QuTiP control modules
102import qutip.control.optimresult as optimresult
103import qutip.control.termcond as termcond
104import qutip.control.errors as errors
105import qutip.control.dynamics as dynamics
106import qutip.control.pulsegen as pulsegen
107import qutip.control.dump as qtrldump
108
109
110# Older versions of SciPy use the method numpy.ndarray.tostring(), which has
111# been deprecated since Numpy 1.19 in favour of the identical-in-all-but-name
112# tobytes() method.  This is simply a deprecated call in SciPy, there's nothing
113# we or our users can do about it, and the function shouldn't actually be
114# removed from Numpy until at least 1.22, by which point we'll have been able
115# to drop support for SciPy 1.4.
# From SciPy 1.5 onwards the library function is used directly.  For older
# releases a thin wrapper mutes the tostring() DeprecationWarning, and only
# for the duration of each call.
if _parse_version(scipy.__version__) >= _parse_version("1.5"):
    fmin_l_bfgs_b = spopt.fmin_l_bfgs_b
else:
    @functools.wraps(spopt.fmin_l_bfgs_b)
    def fmin_l_bfgs_b(*args, **kwargs):
        # The filter is installed inside catch_warnings() so that the
        # caller's warning configuration is restored on exit.
        pattern = r"tostring\(\) is deprecated\. Use tobytes\(\) instead\."
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", message=pattern,
                                    category=DeprecationWarning)
            return spopt.fmin_l_bfgs_b(*args, **kwargs)
126
127
128def _is_string(var):
129    try:
130        if isinstance(var, basestring):
131            return True
132    except NameError:
133        try:
134            if isinstance(var, str):
135                return True
136        except:
137            return False
138    except:
139        return False
140
141    return False
142
143
144class Optimizer(object):
145    """
146    Base class for all control pulse optimisers. This class should not be
147    instantiated, use its subclasses.  This class implements the fidelity,
    gradient and iteration callback functions.  All subclass objects must be
149    initialised with a
150
151    - ``OptimConfig`` instance - various configuration options
152    - ``Dynamics`` instance - describes the dynamics of the (quantum) system
153      to be control optimised
154
155    Attributes
156    ----------
157    log_level : integer
158        level of messaging output from the logger.  Options are attributes of
159        qutip.logging_utils, in decreasing levels of messaging, are:
160        DEBUG_INTENSE, DEBUG_VERBOSE, DEBUG, INFO, WARN, ERROR, CRITICAL
161        Anything WARN or above is effectively 'quiet' execution, assuming
162        everything runs as expected.  The default NOTSET implies that the level
163        will be taken from the QuTiP settings file, which by default is WARN.
164
165    params:  Dictionary
166        The key value pairs are the attribute name and value. Note: attributes
167        are created if they do not exist already, and are overwritten if they
168        do.
169
170    alg : string
171        Algorithm to use in pulse optimisation.  Options are:
172
173        - 'GRAPE' (default) - GRadient Ascent Pulse Engineering
174        - 'CRAB' - Chopped RAndom Basis
175
176    alg_params : Dictionary
177        Options that are specific to the pulse optim algorithm ``alg``.
178
179    disp_conv_msg : bool
180        Set true to display a convergence message
181        (for scipy.optimize.minimize methods anyway)
182
183    optim_method : string
184        a scipy.optimize.minimize method that will be used to optimise
185        the pulse for minimum fidelity error
186
187    method_params : Dictionary
188        Options for the optim_method.
189        Note that where there is an equivalent attribute of this instance
190        or the termination_conditions (for example maxiter)
        it will override any value in these options
192
193    approx_grad : bool
194        If set True then the method will approximate the gradient itself
195        (if it has requirement and facility for this)
196        This will mean that the fid_err_grad_wrapper will not get called
197        Note it should be left False when using the Dynamics
198        to calculate approximate gradients
199        Note it is set True automatically when the alg is CRAB
200
201    amp_lbound : float or list of floats
202        lower boundaries for the control amplitudes
203        Can be a scalar value applied to all controls
204        or a list of bounds for each control
205
206    amp_ubound : float or list of floats
207        upper boundaries for the control amplitudes
208        Can be a scalar value applied to all controls
209        or a list of bounds for each control
210
211    bounds : List of floats
212        Bounds for the parameters.
213        If not set before the run_optimization call then the list
214        is built automatically based on the amp_lbound and amp_ubound
215        attributes.
216        Setting this attribute directly allows specific bounds to be set
217        for individual parameters.
218        Note: Only some methods use bounds
219
220    dynamics : Dynamics (subclass instance)
221        describes the dynamics of the (quantum) system to be control optimised
222        (see Dynamics classes for details)
223
224    config : OptimConfig instance
225        various configuration options
226        (see OptimConfig for details)
227
228    termination_conditions : TerminationCondition instance
229        attributes determine when the optimisation will end
230
231    pulse_generator : PulseGen (subclass instance)
232        (can be) used to create initial pulses
233        not used by the class, but set by pulseoptim.create_pulse_optimizer
234
235    stats : Stats
236        attributes of which give performance stats for the optimisation
237        set to None to reduce overhead of calculating stats.
238        Note it is (usually) shared with the Dynamics instance
239
240    dump : :class:`dump.OptimDump`
241        Container for data dumped during the optimisation.
242        Can be set by specifying the dumping level or set directly.
        Note this is mainly intended for user and developer debugging
244        but could be used for status information during a long optimisation.
245
246    dumping : string
247        level of data dumping: NONE, SUMMARY, FULL or CUSTOM
248        See property docstring for details
249
250    dump_to_file : bool
251        If set True then data will be dumped to file during the optimisation
252        dumping will be set to SUMMARY during init_optim
253        if dump_to_file is True and dumping not set.
254        Default is False
255
256    dump_dir : string
257        Basically a link to dump.dump_dir. Exists so that it can be set through
258        optim_params.
259        If dump is None then will return None or will set dumping to SUMMARY
260        when setting a path
261
262    iter_summary : :class:`OptimIterSummary`
263        Summary of the most recent iteration.
        Note this is only set if dumping is on
265    """
266
267    def __init__(self, config, dyn, params=None):
268        self.dynamics = dyn
269        self.config = config
270        self.params = params
271        self.reset()
272        dyn.parent = self
273
274    def reset(self):
275        self.log_level = self.config.log_level
276        self.id_text = 'OPTIM'
277        self.termination_conditions = None
278        self.pulse_generator = None
279        self.disp_conv_msg = False
280        self.iteration_steps = None
281        self.record_iteration_steps=False
282        self.alg = 'GRAPE'
283        self.alg_params = None
284        self.method = 'l_bfgs_b'
285        self.method_params = None
286        self.method_options = None
287        self.approx_grad = False
288        self.amp_lbound = None
289        self.amp_ubound = None
290        self.bounds = None
291        self.num_iter = 0
292        self.num_fid_func_calls = 0
293        self.num_grad_func_calls = 0
294        self.stats = None
295        self.wall_time_optim_start = 0.0
296
297        self.dump_to_file = False
298        self.dump = None
299        self.iter_summary = None
300
301        # AJGP 2015-04-21:
302        # These (copying from config) are here for backward compatibility
303        if hasattr(self.config, 'amp_lbound'):
304            if self.config.amp_lbound:
305                self.amp_lbound = self.config.amp_lbound
306        if hasattr(self.config, 'amp_ubound'):
307            if self.config.amp_ubound:
308                self.amp_ubound = self.config.amp_ubound
309
310        self.apply_params()
311
    @property
    def log_level(self):
        """
        Level of the module-level ``logger``.  The value is read straight
        from the logger; nothing is stored on the instance.
        """
        return logger.level

    @log_level.setter
    def log_level(self, lvl):
        """
        Set the level of the module-level ``logger``, that is call
        logger.setLevel(lvl).  The logger is a module global, so it is not
        private to this instance.
        """
        logger.setLevel(lvl)
323
324    def apply_params(self, params=None):
325        """
326        Set object attributes based on the dictionary (if any) passed in the
327        instantiation, or passed as a parameter
328        This is called during the instantiation automatically.
329        The key value pairs are the attribute name and value
330        Note: attributes are created if they do not exist already,
331        and are overwritten if they do.
332        """
333        if not params:
334            params = self.params
335
336        if isinstance(params, dict):
337            self.params = params
338            for key in params:
339                setattr(self, key, params[key])
340
341    @property
342    def dumping(self):
343        """
344        The level of data dumping that will occur during the optimisation
345
346        - NONE : No processing data dumped (Default)
347        - SUMMARY : A summary at each iteration will be recorded
348        - FULL : All logs will be generated and dumped
349        - CUSTOM : Some customised level of dumping
350
351        When first set to CUSTOM this is equivalent to SUMMARY. It is then up
352        to the user to specify which logs are dumped
353        """
354        if self.dump is None:
355            lvl = 'NONE'
356        else:
357            lvl = self.dump.level
358
359        return lvl
360
361    @dumping.setter
362    def dumping(self, value):
363        if value is None:
364            self.dump = None
365        else:
366            if not _is_string(value):
367                raise TypeError("Value must be string value")
368            lvl = value.upper()
369            if lvl == 'NONE':
370                self.dump = None
371            else:
372                if not isinstance(self.dump, qtrldump.OptimDump):
373                    self.dump = qtrldump.OptimDump(self, level=lvl)
374                else:
375                    self.dump.level = lvl
376    @property
377    def dump_dir(self):
378        if self.dump:
379            return self.dump.dump_dir
380        else:
381            return None
382
383    @dump_dir.setter
384    def dump_dir(self, value):
385        if not self.dump:
386            self.dumping = 'SUMMARY'
387        self.dump.dump_dir = value
388
389    def _create_result(self):
390        """
391        create the result object
392        and set the initial_amps attribute as the current amplitudes
393        """
394        result = optimresult.OptimResult()
395        result.initial_fid_err = self.dynamics.fid_computer.get_fid_err()
396        result.initial_amps = self.dynamics.ctrl_amps.copy()
397        result.evo_full_initial = self.dynamics.full_evo.copy()
398        result.time = self.dynamics.time.copy()
399        result.optimizer = self
400        return result
401
402    def init_optim(self, term_conds):
403        """
404        Check optimiser attribute status and passed parameters before
405        running the optimisation.
406        This is called by run_optimization, but could called independently
407        to check the configuration.
408        """
409        if term_conds is not None:
410            self.termination_conditions = term_conds
411        term_conds = self.termination_conditions
412
413        if not isinstance(term_conds, termcond.TerminationConditions):
414            raise errors.UsageError("No termination conditions for the "
415                                    "optimisation function")
416
417        if not isinstance(self.dynamics, dynamics.Dynamics):
418            raise errors.UsageError("No dynamics object attribute set")
419        self.dynamics.check_ctrls_initialized()
420
421        self.apply_method_params()
422
423        if term_conds.fid_err_targ is None and term_conds.fid_goal is None:
424            raise errors.UsageError("Either the goal or the fidelity "
425                                    "error tolerance must be set")
426
427        if term_conds.fid_err_targ is None:
428            term_conds.fid_err_targ = np.abs(1 - term_conds.fid_goal)
429
430        if term_conds.fid_goal is None:
431            term_conds.fid_goal = 1 - term_conds.fid_err_targ
432
433        if self.alg == 'CRAB':
434            self.approx_grad = True
435
436        if self.stats is not None:
437            self.stats.clear()
438
439        if self.dump_to_file:
440            if self.dump is None:
441                self.dumping = 'SUMMARY'
442            self.dump.write_to_file = True
443            self.dump.create_dump_dir()
444            logger.info("Optimiser dump will be written to:\n{}".format(
445                                        self.dump.dump_dir))
446
447        if self.dump:
448            self.iter_summary = OptimIterSummary()
449        else:
450            self.iter_summary = None
451
452        self.num_iter = 0
453        self.num_fid_func_calls = 0
454        self.num_grad_func_calls = 0
455        self.iteration_steps = None
456
457    def _build_method_options(self):
458        """
459        Creates the method_options dictionary for the scipy.optimize.minimize
460        function based on the attributes of this object and the
461        termination_conditions
462        It assumes that apply_method_params has already been run and
463        hence the method_options attribute may already contain items.
464        These values will NOT be overridden
465        """
466        tc = self.termination_conditions
467        if self.method_options is None:
468            self.method_options = {}
469        mo = self.method_options
470
471        if 'max_metric_corr' in mo and not 'maxcor' in mo:
472            mo['maxcor'] = mo['max_metric_corr']
473        elif hasattr(self, 'max_metric_corr') and not 'maxcor' in mo:
474            mo['maxcor'] = self.max_metric_corr
475        if 'accuracy_factor' in mo  and not 'ftol' in mo:
476            mo['ftol'] = mo['accuracy_factor']
477        elif hasattr(tc, 'accuracy_factor') and not 'ftol' in mo:
478            mo['ftol'] = tc.accuracy_factor
479        if tc.max_iterations > 0 and not 'maxiter' in mo:
480            mo['maxiter'] = tc.max_iterations
481        if tc.max_fid_func_calls > 0 and not 'maxfev' in mo:
482            mo['maxfev'] = tc.max_fid_func_calls
483        if tc.min_gradient_norm > 0 and not 'gtol' in mo:
484            mo['gtol'] = tc.min_gradient_norm
485        if not 'disp' in mo:
486            mo['disp'] = self.disp_conv_msg
487
488        return mo
489
490    def apply_method_params(self, params=None):
491        """
492        Loops through all the method_params
493        (either passed here or the method_params attribute)
494        If the name matches an attribute of this object or the
495        termination conditions object, then the value of this attribute
496        is set. Otherwise it is assumed to a method_option for the
497        scipy.optimize.minimize function
498        """
499        if not params:
500            params = self.method_params
501
502        if isinstance(params, dict):
503            self.method_params = params
504            unused_params = {}
505            for key in params:
506                val = params[key]
507                if hasattr(self, key):
508                    setattr(self, key, val)
509                if hasattr(self.termination_conditions, key):
510                    setattr(self.termination_conditions, key, val)
511                else:
512                    unused_params[key] = val
513
514            if len(unused_params) > 0:
515                if not isinstance(self.method_options, dict):
516                    self.method_options = unused_params
517                else:
518                    self.method_options.update(unused_params)
519
520    def _build_bounds_list(self):
521        cfg = self.config
522        dyn = self.dynamics
523        n_ctrls = dyn.num_ctrls
524        self.bounds = []
525        for t in range(dyn.num_tslots):
526            for c in range(n_ctrls):
527                if isinstance(self.amp_lbound, list):
528                    lb = self.amp_lbound[c]
529                else:
530                    lb = self.amp_lbound
531                if isinstance(self.amp_ubound, list):
532                    ub = self.amp_ubound[c]
533                else:
534                    ub = self.amp_ubound
535
536                if not lb is None and np.isinf(lb):
537                    lb = None
538                if not ub is None and np.isinf(ub):
539                    ub = None
540
541                self.bounds.append((lb, ub))
542
543    def run_optimization(self, term_conds=None):
544        """
545        This default function optimisation method is a wrapper to the
546        scipy.optimize.minimize function.
547
548        It will attempt to minimise the fidelity error with respect to some
549        parameters, which are determined by _get_optim_var_vals (see below)
550
551        The optimisation end when one of the passed termination conditions
552        has been met, e.g. target achieved, wall time, or
553        function call or iteration count exceeded. Note these
554        conditions include gradient minimum met (local minima) for
555        methods that use a gradient.
556
557        The function minimisation method is taken from the optim_method
558        attribute. Note that not all of these methods have been tested.
559        Note that some of these use a gradient and some do not.
560        See the scipy documentation for details. Options specific to the
561        method can be passed setting the method_params attribute.
562
563        If the parameter term_conds=None, then the termination_conditions
564        attribute must already be set. It will be overwritten if the
565        parameter is not None
566
567        The result is returned in an OptimResult object, which includes
568        the final fidelity, time evolution, reason for termination etc
569
570        """
571        self.init_optim(term_conds)
572        term_conds = self.termination_conditions
573        dyn = self.dynamics
574        cfg = self.config
575        self.optim_var_vals = self._get_optim_var_vals()
576        st_time = timeit.default_timer()
577        self.wall_time_optimize_start = st_time
578
579        if self.stats is not None:
580            self.stats.wall_time_optim_start = st_time
581            self.stats.wall_time_optim_end = 0.0
582            self.stats.num_iter = 0
583
584        if self.bounds is None:
585            self._build_bounds_list()
586
587        self._build_method_options()
588
589        result = self._create_result()
590
591        if self.approx_grad:
592            jac=None
593        else:
594            jac=self.fid_err_grad_wrapper
595
596        if self.log_level <= logging.INFO:
597            msg = ("Optimising pulse(s) using {} with "
598                        "minimise '{}' method").format(self.alg, self.method)
599            if self.approx_grad:
600                msg += " (approx grad)"
601            logger.info(msg)
602
603        try:
604            opt_res = spopt.minimize(
605                self.fid_err_func_wrapper, self.optim_var_vals,
606                method=self.method,
607                jac=jac,
608                bounds=self.bounds,
609                options=self.method_options,
610                callback=self.iter_step_callback_func)
611
612            amps = self._get_ctrl_amps(opt_res.x)
613            dyn.update_ctrl_amps(amps)
614            result.termination_reason = opt_res.message
615            # Note the iterations are counted in this object as well
616            # so there are compared here for interest sake only
617            if self.num_iter != opt_res.nit:
618                logger.info("The number of iterations counted {} "
619                            " does not match the number reported {} "
620                            "by {}".format(self.num_iter, opt_res.nit,
621                                            self.method))
622            result.num_iter = opt_res.nit
623
624        except errors.OptimizationTerminate as except_term:
625            self._interpret_term_exception(except_term, result)
626
627        end_time = timeit.default_timer()
628        self._add_common_result_attribs(result, st_time, end_time)
629
630        return result
631
632    def _get_optim_var_vals(self):
633        """
634        Generate the 1d array that holds the current variable values
635        of the function to be optimised
636        By default (as used in GRAPE) these are the control amplitudes
637        in each timeslot
638        """
639        return self.dynamics.ctrl_amps.reshape([-1])
640
641    def _get_ctrl_amps(self, optim_var_vals):
642        """
643        Get the control amplitudes from the current variable values
644        of the function to be optimised.
645        that is the 1d array that is passed from the optimisation method
646        Note for GRAPE these are the function optimiser parameters
647        (and this is the default)
648
649        Returns
650        -------
651        float array[dynamics.num_tslots, dynamics.num_ctrls]
652        """
653        amps = optim_var_vals.reshape(self.dynamics.ctrl_amps.shape)
654
655        return amps
656
657    def fid_err_func_wrapper(self, *args):
658        """
659        Get the fidelity error achieved using the ctrl amplitudes passed
660        in as the first argument.
661
662        This is called by generic optimisation algorithm as the
663        func to the minimised. The argument is the current
664        variable values, i.e. control amplitudes, passed as
665        a flat array. Hence these are reshaped as [nTimeslots, n_ctrls]
666        and then used to update the stored ctrl values (if they have changed)
667
668        The error is checked against the target, and the optimisation is
669        terminated if the target has been achieved.
670        """
671        self.num_fid_func_calls += 1
672        # *** update stats ***
673        if self.stats is not None:
674            self.stats.num_fidelity_func_calls = self.num_fid_func_calls
675            if self.log_level <= logging.DEBUG:
676                logger.debug("fidelity error call {}".format(
677                    self.stats.num_fidelity_func_calls))
678
679        amps = self._get_ctrl_amps(args[0].copy())
680        self.dynamics.update_ctrl_amps(amps)
681
682        tc = self.termination_conditions
683        err = self.dynamics.fid_computer.get_fid_err()
684
685        if self.iter_summary:
686            self.iter_summary.fid_func_call_num = self.num_fid_func_calls
687            self.iter_summary.fid_err = err
688
689        if self.dump and self.dump.dump_fid_err:
690            self.dump.update_fid_err_log(err)
691
692        if err <= tc.fid_err_targ:
693            raise errors.GoalAchievedTerminate(err)
694
695        if self.num_fid_func_calls > tc.max_fid_func_calls:
696            raise errors.MaxFidFuncCallTerminate()
697
698        return err
699
700    def fid_err_grad_wrapper(self, *args):
701        """
702        Get the gradient of the fidelity error with respect to all of the
703        variables, i.e. the ctrl amplidutes in each timeslot
704
705        This is called by generic optimisation algorithm as the gradients of
706        func to the minimised wrt the variables. The argument is the current
707        variable values, i.e. control amplitudes, passed as
708        a flat array. Hence these are reshaped as [nTimeslots, n_ctrls]
709        and then used to update the stored ctrl values (if they have changed)
710
711        Although the optimisation algorithms have a check within them for
712        function convergence, i.e. local minima, the sum of the squares
713        of the normalised gradient is checked explicitly, and the
714        optimisation is terminated if this is below the min_gradient_norm
715        condition
716        """
717        # *** update stats ***
718        self.num_grad_func_calls += 1
719        if self.stats is not None:
720            self.stats.num_grad_func_calls = self.num_grad_func_calls
721            if self.log_level <= logging.DEBUG:
722                logger.debug("gradient call {}".format(
723                    self.stats.num_grad_func_calls))
724        amps = self._get_ctrl_amps(args[0].copy())
725        self.dynamics.update_ctrl_amps(amps)
726        fid_comp = self.dynamics.fid_computer
727        # gradient_norm_func is a pointer to the function set in the config
728        # that returns the normalised gradients
729        grad = fid_comp.get_fid_err_gradient()
730
731        if self.iter_summary:
732            self.iter_summary.grad_func_call_num = self.num_grad_func_calls
733            self.iter_summary.grad_norm = fid_comp.grad_norm
734
735        if self.dump:
736            if self.dump.dump_grad_norm:
737                self.dump.update_grad_norm_log(fid_comp.grad_norm)
738
739            if self.dump.dump_grad:
740                self.dump.update_grad_log(grad)
741
742        tc = self.termination_conditions
743        if fid_comp.grad_norm < tc.min_gradient_norm:
744            raise errors.GradMinReachedTerminate(fid_comp.grad_norm)
745        return grad.flatten()
746
747    def iter_step_callback_func(self, *args):
748        """
749        Check the elapsed wall time for the optimisation run so far.
750        Terminate if this has exceeded the maximum allowed time
751        """
752        self.num_iter += 1
753
754        if self.log_level <= logging.DEBUG:
755            logger.debug("Iteration callback {}".format(self.num_iter))
756
757        wall_time = timeit.default_timer() - self.wall_time_optimize_start
758
759        if self.iter_summary:
760            self.iter_summary.iter_num = self.num_iter
761            self.iter_summary.wall_time = wall_time
762
763        if self.dump and self.dump.dump_summary:
764            self.dump.add_iter_summary()
765
766        tc = self.termination_conditions
767
768        if wall_time > tc.max_wall_time:
769            raise errors.MaxWallTimeTerminate()
770
771        # *** update stats ***
772        if self.stats is not None:
773            self.stats.num_iter = self.num_iter
774
775    def _interpret_term_exception(self, except_term, result):
776        """
777        Update the result object based on the exception that occurred
778        during the optimisation
779        """
780        result.termination_reason = except_term.reason
781        if isinstance(except_term, errors.GoalAchievedTerminate):
782            result.goal_achieved = True
783        elif isinstance(except_term, errors.MaxWallTimeTerminate):
784            result.wall_time_limit_exceeded = True
785        elif isinstance(except_term, errors.GradMinReachedTerminate):
786            result.grad_norm_min_reached = True
787        elif isinstance(except_term, errors.MaxFidFuncCallTerminate):
788            result.max_fid_func_exceeded = True
789
790    def _add_common_result_attribs(self, result, st_time, end_time):
791        """
792        Update the result object attributes which are common to all
793        optimisers and outcomes
794        """
795        dyn = self.dynamics
796        result.num_iter = self.num_iter
797        result.num_fid_func_calls = self.num_fid_func_calls
798        result.wall_time = end_time - st_time
799        result.fid_err = dyn.fid_computer.get_fid_err()
800        result.grad_norm_final = dyn.fid_computer.grad_norm
801        result.final_amps = dyn.ctrl_amps
802        final_evo = dyn.full_evo
803        if isinstance(final_evo, Qobj):
804            result.evo_full_final = final_evo
805        else:
806            result.evo_full_final = Qobj(final_evo, dims=dyn.sys_dims)
807        # *** update stats ***
808        if self.stats is not None:
809            self.stats.wall_time_optim_end = end_time
810            self.stats.calculate()
811            result.stats = copy.copy(self.stats)
812
813
class OptimizerBFGS(Optimizer):
    """
    Implements the run_optimization method using the BFGS algorithm
    """
    def reset(self):
        """
        Call Optimizer.reset and then set the id text for this optimiser
        """
        Optimizer.reset(self)
        self.id_text = 'BFGS'

    def run_optimization(self, term_conds=None):
        """
        Optimise the control pulse amplitudes to minimise the fidelity error
        using the BFGS (Broyden–Fletcher–Goldfarb–Shanno) algorithm
        The optimisation ends when one of the passed termination conditions
        has been met, e.g. target achieved, gradient minimum met
        (local minima), wall time / iteration count exceeded.

        Essentially this is a wrapper to the:
        scipy.optimize.fmin_bfgs
        function

        If the parameter term_conds=None, then the termination_conditions
        attribute must already be set. It will be overwritten if the
        parameter is not None

        The result is returned in an OptimResult object, which includes
        the final fidelity, time evolution, reason for termination etc

        Parameters
        ----------
        term_conds : optional
            Termination conditions, passed on to init_optim

        Returns
        -------
        The object made by _create_result, updated with the outcome of
        the optimisation
        """
        self.init_optim(term_conds)
        term_conds = self.termination_conditions
        dyn = self.dynamics
        self.optim_var_vals = self._get_optim_var_vals()
        self._build_method_options()

        st_time = timeit.default_timer()
        self.wall_time_optimize_start = st_time

        if self.stats is not None:
            self.stats.wall_time_optim_start = st_time
            self.stats.wall_time_optim_end = 0.0
            self.stats.num_iter = 1

        # With fprime=None fmin_bfgs approximates the gradient itself
        if self.approx_grad:
            fprime = None
        else:
            fprime = self.fid_err_grad_wrapper

        if self.log_level <= logging.INFO:
            msg = ("Optimising pulse(s) using {} with "
                        "'fmin_bfgs' method").format(self.alg)
            if self.approx_grad:
                msg += " (approx grad)"
            logger.info(msg)

        result = self._create_result()
        try:
            # full_output gives the tuple:
            # (xopt, fopt, gopt, Bopt, func_calls, grad_calls, warnflag)
            # only the first and last items are used here
            bfgs_out = spopt.fmin_bfgs(self.fid_err_func_wrapper,
                                       self.optim_var_vals,
                                       fprime=fprime,
                                       callback=self.iter_step_callback_func,
                                       gtol=term_conds.min_gradient_norm,
                                       maxiter=term_conds.max_iterations,
                                       full_output=True, disp=True)
            optim_var_vals = bfgs_out[0]
            warn = bfgs_out[6]

            amps = self._get_ctrl_amps(optim_var_vals)
            dyn.update_ctrl_amps(amps)
            if warn == 0:
                # scipy's success code. This was previously left
                # unhandled, meaning the result had no termination reason
                # NOTE(review): success normally means that the gradient
                # norm is within gtol - confirm for the scipy version used
                result.grad_norm_min_reached = True
                result.termination_reason = "function converged"
            elif warn == 1:
                result.max_iter_exceeded = True
                result.termination_reason = "Iteration count limit reached"
            elif warn == 2:
                # NOTE(review): scipy documents this flag as 'gradient
                # and/or function calls not changing'. The flag and text
                # set here are kept as they were for compatibility
                result.grad_norm_min_reached = True
                result.termination_reason = "Gradient normal minimum reached"
            elif warn == 3:
                result.termination_reason = "NaN result encountered"
            else:
                result.termination_reason = \
                    "Unknown (warn_flag={})".format(warn)

        except errors.OptimizationTerminate as except_term:
            self._interpret_term_exception(except_term, result)

        end_time = timeit.default_timer()
        self._add_common_result_attribs(result, st_time, end_time)

        return result
895
896
class OptimizerLBFGSB(Optimizer):
    """
    Implements the run_optimization method using the L-BFGS-B algorithm

    Attributes
    ----------
    max_metric_corr : integer
        The maximum number of variable metric corrections used to define
        the limited memory matrix. That is the number of previous
        gradient values that are used to approximate the Hessian
        see the scipy.optimize.fmin_l_bfgs_b documentation for description
        of m argument

    msg_level : integer
        Passed to fmin_l_bfgs_b as its disp argument. When it is not an
        integer (None after a reset) it is derived from the log_level
        in init_optim

    """

    def reset(self):
        """
        Call Optimizer.reset and then set the defaults specific to
        this optimiser
        """
        Optimizer.reset(self)
        self.id_text = 'LBFGSB'
        self.max_metric_corr = 10
        self.msg_level = None

    def init_optim(self, term_conds):
        """
        Check optimiser attribute status and passed parameters before
        running the optimisation.
        This is called by run_optimization, but could be called
        independently to check the configuration.
        """
        if term_conds is None:
            term_conds = self.termination_conditions

        # AJGP 2015-04-21:
        # These (copying from config) are here for backward compatibility
        cfg = self.config
        if hasattr(cfg, 'max_metric_corr') and cfg.max_metric_corr:
            self.max_metric_corr = cfg.max_metric_corr
        if hasattr(cfg, 'accuracy_factor') and cfg.accuracy_factor:
            term_conds.accuracy_factor = cfg.accuracy_factor

        Optimizer.init_optim(self, term_conds)

        # An integer msg_level that has been set explicitly is left alone
        if isinstance(self.msg_level, int):
            return

        # Otherwise the more detailed the logging, the higher the level
        if self.log_level < logging.DEBUG:
            self.msg_level = 2
        elif self.log_level <= logging.DEBUG:
            self.msg_level = 1
        else:
            self.msg_level = 0

    def run_optimization(self, term_conds=None):
        """
        Optimise the control pulse amplitudes to minimise the fidelity error
        using the L-BFGS-B algorithm, which is the constrained
        (bounded amplitude values), limited memory, version of the
        Broyden–Fletcher–Goldfarb–Shanno algorithm.

        The optimisation ends when one of the passed termination conditions
        has been met, e.g. target achieved, gradient minimum met
        (local minima), wall time / iteration count exceeded.

        Essentially this is a wrapper to the:
        scipy.optimize.fmin_l_bfgs_b function
        This in turn is a wrapper for a well established implementation
        of the L-BFGS-B algorithm written in Fortran, which is therefore
        very fast. See SciPy documentation for credit and details on
        this function.

        The accuracy factor (factr) is taken from the first of these
        that is found: method_options 'accuracy_factor', method_options
        'ftol', the accuracy_factor of the termination conditions,
        otherwise 1e7. Likewise the number of metric corrections (m)
        comes from: method_options 'max_metric_corr', method_options
        'maxcor', the max_metric_corr attribute, otherwise 10.

        If the parameter term_conds=None, then the termination_conditions
        attribute must already be set. It will be overwritten if the
        parameter is not None

        The result is returned in an OptimResult object, which includes
        the final fidelity, time evolution, reason for termination etc

        """
        self.init_optim(term_conds)
        term_conds = self.termination_conditions
        dyn = self.dynamics
        self.optim_var_vals = self._get_optim_var_vals()
        self._build_method_options()

        st_time = timeit.default_timer()
        self.wall_time_optimize_start = st_time

        stats = self.stats
        if stats is not None:
            stats.wall_time_optim_start = st_time
            stats.wall_time_optim_end = 0.0
            stats.num_iter = 1

        # The return value is not used; it is self.bounds that is handed
        # to the optimisation function below
        self._build_bounds_list()
        result = self._create_result()

        fprime = None if self.approx_grad else self.fid_err_grad_wrapper

        opts = self.method_options

        for key in ('accuracy_factor', 'ftol'):
            if key in opts:
                factr = opts[key]
                break
        else:
            factr = getattr(term_conds, 'accuracy_factor', 1e7)

        for key in ('max_metric_corr', 'maxcor'):
            if key in opts:
                m = opts[key]
                break
        else:
            m = getattr(self, 'max_metric_corr', 10)

        if self.log_level <= logging.INFO:
            msg = ("Optimising pulse(s) using {} with "
                        "'fmin_l_bfgs_b' method").format(self.alg)
            if self.approx_grad:
                msg += " (approx grad)"
            logger.info(msg)

        try:
            final_vals, _, info = fmin_l_bfgs_b(
                self.fid_err_func_wrapper, self.optim_var_vals,
                fprime=fprime,
                approx_grad=self.approx_grad,
                callback=self.iter_step_callback_func,
                bounds=self.bounds, m=m, factr=factr,
                pgtol=term_conds.min_gradient_norm,
                disp=self.msg_level,
                maxfun=term_conds.max_fid_func_calls,
                maxiter=term_conds.max_iterations)

            dyn.update_ctrl_amps(self._get_ctrl_amps(final_vals))

            warn = info['warnflag']
            if warn == 0:
                result.grad_norm_min_reached = True
                result.termination_reason = "function converged"
            elif warn == 1:
                result.max_iter_exceeded = True
                result.termination_reason = ("Iteration or fidelity "
                                             "function call limit reached")
            elif warn == 2:
                # The optimisation function gives its own reason
                result.termination_reason = info['task']

            result.num_iter = info['nit']
        except errors.OptimizationTerminate as except_term:
            self._interpret_term_exception(except_term, result)

        end_time = timeit.default_timer()
        self._add_common_result_attribs(result, st_time, end_time)

        return result
1055
class OptimizerCrab(Optimizer):
    """
    Optimises the pulse using the CRAB algorithm [1]_.
    It uses the scipy.optimize.minimize function with the method specified
    by the optim_method attribute. See Optimizer.run_optimization for details
    It minimises the fidelity error function with respect to the CRAB
    basis function coefficients.

    References
    ----------
    .. [1] P. Doria, T. Calarco & S. Montangero. Phys. Rev. Lett. 106, 190501
       (2011).
    """

    def reset(self):
        """
        Call Optimizer.reset and then set the defaults specific to
        this optimiser
        """
        Optimizer.reset(self)
        self.id_text = 'CRAB'
        self.num_optim_vars = 0

    def init_optim(self, term_conds):
        """
        Check optimiser attribute status and passed parameters before
        running the optimisation.
        This is called by run_optimization, but could be called
        independently to check the configuration.

        The pulse_generator attribute must hold one PulseGenCrab for each
        control. The num_optim_vars attribute is set to the total number
        of optimisation variables of these pulse generators.

        Raises
        ------
        errors.UsageError
            If the pulse_generator attribute is not valid
        """
        Optimizer.init_optim(self, term_conds)
        dyn = self.dynamics

        self.num_optim_vars = 0
        pgens = self.pulse_generator
        # Stays None for as long as the pulse generators look valid
        err_msg = None

        # check the pulse generators match the ctrls
        # (in terms of number)
        # and count the number of parameters
        if pgens is None:
            err_msg = "pulse_generator attribute is None"
        elif not isinstance(pgens, collections.abc.Iterable):
            err_msg = "pulse_generator is not iterable"
        elif len(pgens) != dyn.num_ctrls:
            err_msg = ("the number of pulse generators {} does not equal "
                        "the number of controls {}".format(
                        len(pgens), dyn.num_ctrls))
        else:
            for p_gen in pgens:
                if not isinstance(p_gen, pulsegen.PulseGenCrab):
                    err_msg = (
                        "pulse_generator contained object of type '{}'".format(
                        p_gen.__class__.__name__))
                    break
                self.num_optim_vars += p_gen.num_optim_vars

        if err_msg is not None:
            raise errors.UsageError(
                "The pulse_generator attribute must be set to a list of "
                "PulseGenCrab - one for each control. Here " + err_msg)

    def _build_bounds_list(self):
        """
        No bounds necessary here, as the bounds for the CRAB parameters
        do not have much physical meaning.
        This needs to override the default method, otherwise the shape
        will be wrong
        """
        return None

    def _get_optim_var_vals(self):
        """
        Generate the 1d array that holds the current variable values
        of the function to be optimised
        For CRAB these are the basis coefficients, taken from each of
        the pulse generators in turn

        Returns
        -------
        ndarray (1d) of float

        """
        return np.array([val
                         for p_gen in self.pulse_generator
                         for val in p_gen.get_optim_var_vals()])

    def _get_ctrl_amps(self, optim_var_vals):
        """
        Get the control amplitudes from the current variable values
        of the function to be optimised.
        that is the 1d array that is passed from the optimisation method
        For CRAB the amplitudes will need to be calculated by expanding
        the series

        Each pulse generator is given its own slice of the variable
        values and the pulse that it then generates is used as the
        amplitudes for the corresponding control. The optim_var_vals
        attribute is updated with the values passed.

        Returns
        -------
        float array[dynamics.num_tslots, dynamics.num_ctrls]
        """
        dyn = self.dynamics

        if self.log_level <= logging.DEBUG:
            changed_params = self.optim_var_vals != optim_var_vals
            logger.debug(
                "{} out of {} optimisation parameters changed".format(
                    changed_params.sum(), len(optim_var_vals)))

        amps = np.empty([dyn.num_tslots, dyn.num_ctrls])
        start = 0
        for ctrl_idx, p_gen in enumerate(self.pulse_generator):
            # The variables for this generator follow on directly from
            # those of the previous one
            stop = start + p_gen.num_optim_vars
            p_gen.set_optim_var_vals(optim_var_vals[start:stop])
            amps[:, ctrl_idx] = p_gen.gen_pulse()
            start = stop

        self.optim_var_vals = optim_var_vals
        return amps
1175
class OptimizerCrabFmin(OptimizerCrab):
    """
    Optimises the pulse using the CRAB algorithm [1]_, [2]_.
    It uses the ``scipy.optimize.fmin`` function which is effectively a wrapper
    for the Nelder-Mead method.  It minimises the fidelity error function with
    respect to the CRAB basis function coefficients.  This is the default
    Optimizer for CRAB.

    References
    ----------
    .. [1] P. Doria, T. Calarco & S. Montangero. Phys. Rev. Lett. 106, 190501
       (2011).
    .. [2] T. Caneva, T. Calarco, & S. Montangero. Phys. Rev. A 84, 022326
       (2011).
    """

    def reset(self):
        """
        Call OptimizerCrab.reset and then set the id text and the
        default tolerances that are passed to fmin
        """
        OptimizerCrab.reset(self)
        self.id_text = 'CRAB_FMIN'
        self.xtol = 1e-4
        self.ftol = 1e-4

    def run_optimization(self, term_conds=None):
        """
        This optimisation method is a wrapper to the
        scipy.optimize.fmin function.

        It will attempt to minimise the fidelity error with respect to some
        parameters, which are determined by _get_optim_var_vals which
        in the case of CRAB are the basis function coefficients

        The optimisation ends when one of the passed termination conditions
        has been met, e.g. target achieved, wall time, or
        function call or iteration count exceeded. Specifically to the fmin
        method, the optimisation will stop when the change in parameter
        values is less than xtol or the change in function value is
        below ftol.

        If the parameter term_conds=None, then the termination_conditions
        attribute must already be set. It will be overwritten if the
        parameter is not None

        The result is returned in an OptimResult object, which includes
        the final fidelity, time evolution, reason for termination etc
        """
        self.init_optim(term_conds)
        term_conds = self.termination_conditions
        dyn = self.dynamics
        self.optim_var_vals = self._get_optim_var_vals()
        self._build_method_options()

        st_time = timeit.default_timer()
        self.wall_time_optimize_start = st_time

        stats = self.stats
        if stats is not None:
            stats.wall_time_optim_start = st_time
            stats.wall_time_optim_end = 0.0
            stats.num_iter = 1

        result = self._create_result()

        if self.log_level <= logging.INFO:
            logger.info("Optimising pulse(s) using {} with "
                        "'fmin' (Nelder-Mead) method".format(self.alg))

        # warn flag from fmin -> (termination reason, result flag to set)
        outcomes = {
            0: ("Function converged (within tolerance)", None),
            1: ("Maximum number of function evaluations reached",
                'max_fid_func_exceeded'),
            2: ("Maximum number of iterations reached",
                'max_iter_exceeded'),
        }

        try:
            # full_output gives (xopt, fopt, iter, funcalls, warnflag)
            # with the list of solutions appended when retall is set
            fmin_out = spopt.fmin(
                    self.fid_err_func_wrapper, self.optim_var_vals,
                    xtol=self.xtol, ftol=self.ftol,
                    maxiter=term_conds.max_iterations,
                    maxfun=term_conds.max_fid_func_calls,
                    full_output=True, disp=self.disp_conv_msg,
                    retall=self.record_iteration_steps,
                    callback=self.iter_step_callback_func)

            final_param_vals = fmin_out[0]
            num_iter = fmin_out[2]
            warn_flag = fmin_out[4]
            if self.record_iteration_steps:
                self.iteration_steps = fmin_out[5]
            dyn.update_ctrl_amps(self._get_ctrl_amps(final_param_vals))

            # Note the iterations are counted in this object as well
            # so they are compared here for interest sake only
            if self.num_iter != num_iter:
                logger.info("The number of iterations counted {} "
                            " does not match the number reported {} "
                            "by {}".format(self.num_iter, num_iter,
                                            self.method))
            result.num_iter = num_iter

            reason, flag_name = outcomes.get(
                warn_flag, ("Unknown (warn_flag={})".format(warn_flag), None))
            result.termination_reason = reason
            if flag_name is not None:
                setattr(result, flag_name, True)

        except errors.OptimizationTerminate as except_term:
            self._interpret_term_exception(except_term, result)

        end_time = timeit.default_timer()
        self._add_common_result_attribs(result, st_time, end_time)

        return result
1290
class OptimIterSummary(qtrldump.DumpSummaryItem):
    """A summary of the most recent iteration of the pulse optimisation

    Attributes
    ----------
    iter_num : int
        Iteration number of the pulse optimisation

    fid_func_call_num : int
        Fidelity function call number of the pulse optimisation

    grad_func_call_num : int
        Gradient function call number of the pulse optimisation

    fid_err : float
        Fidelity error

    grad_norm : float
        fidelity gradient (wrt the control parameters) vector norm
        that is the magnitude of the gradient

    wall_time : float
        Time spent computing the pulse optimisation so far
        (in seconds of elapsed time)
    """
    # Note there is some duplication here with Optimizer attributes
    # this exists solely to be copied into the summary dump
    min_col_width = 11

    # One row for each summary column:
    # (property name, format type, format precision)
    # The rows are split into the three parallel tuples
    (summary_property_names,
     summary_property_fmt_type,
     summary_property_fmt_prec) = zip(
        ("idx", 'd', 0),
        ("iter_num", 'd', 0),
        ("fid_func_call_num", 'd', 0),
        ("grad_func_call_num", 'd', 0),
        ("fid_err", 'g', 4),
        ("grad_norm", 'g', 4),
        ("wall_time", 'g', 2),
    )

    def __init__(self):
        self.reset()

    def reset(self):
        """
        Call DumpSummaryItem.reset and then clear the iteration values:
        all are set to None, other than the wall_time which is set to 0.0
        """
        qtrldump.DumpSummaryItem.reset(self)
        for attr_name in ('iter_num', 'fid_func_call_num',
                          'grad_func_call_num', 'fid_err', 'grad_norm'):
            setattr(self, attr_name, None)
        self.wall_time = 0.0
1345