1import os
2import numpy as np
3from ase import io, units
4from ase.optimize import QuasiNewton
5from ase.parallel import paropen, world
6from ase.md import VelocityVerlet
7from ase.md import MDLogger
8from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
9
10
class MinimaHopping:
    """Implements the minima hopping method of global optimization outlined
    by S. Goedecker, J. Chem. Phys. 120: 9911 (2004). Initialize with an
    ASE atoms object. Optional parameters are fed through keywords.
    To run multiple searches in parallel, specify the minima_traj keyword,
    and have each run point to the same path.

    For each step ``N`` (zero-padded to five digits) the run writes
    ``mdN.traj``/``mdN.log`` for the molecular dynamics and
    ``qnN.traj``/``qnN.log`` for the local optimization, in addition to the
    text log (``logfile``) and the minima list (``minima_traj``).
    """

    _default_settings = {
        'T0': 1000.,  # K, initial MD 'temperature'
        'beta1': 1.1,  # T factor when the last minimum is re-found
        'beta2': 1.1,  # T factor when an earlier-found minimum is found
        'beta3': 1. / 1.1,  # T factor when a new minimum is found
        'Ediff0': 0.5,  # eV, initial energy acceptance threshold
        'alpha1': 0.98,  # Ediff factor when a new minimum is accepted
        'alpha2': 1. / 0.98,  # Ediff factor when a new minimum is rejected
        'mdmin': 2,  # criteria to stop MD simulation (no. of minima)
        'logfile': 'hop.log',  # text log
        'minima_threshold': 0.5,  # A, threshold for identical configs
        'timestep': 1.0,  # fs, timestep for MD simulations
        'optimizer': QuasiNewton,  # local optimizer to use
        'minima_traj': 'minima.traj',  # storage file for minima list
        'fmax': 0.05}  # eV/A, max force for optimizations

    def __init__(self, atoms, **kwargs):
        """Initialize with an ASE atoms object and keyword arguments.

        Every key of ``_default_settings`` may be overridden by a keyword
        of the same name; each setting is stored as the attribute
        ``_<key>``.

        Raises:
            RuntimeError: if a keyword is not one of the known settings.
        """
        self._atoms = atoms
        for key in kwargs:
            if key not in self._default_settings:
                raise RuntimeError('Unknown keyword: %s' % key)
        for k, v in self._default_settings.items():
            setattr(self, '_%s' % k, kwargs.pop(k, v))

        # when a MD sim. has passed a local minimum:
        self._passedminimum = PassedMinimum()

        # Misc storage.
        self._previous_optimum = None
        self._previous_energy = None
        self._temperature = self._T0
        self._Ediff = self._Ediff0

    def __call__(self, totalsteps=None, maxtemp=None):
        """Run the minima hopping algorithm. Can specify stopping criteria
        with total steps allowed or maximum searching temperature allowed.
        If neither is specified, runs indefinitely (or until stopped by
        batching software).

        Parameters:
            totalsteps: stop once the step counter reaches this value. The
                counter includes the initial optimization and any steps of
                a resumed run. None means no limit.
            maxtemp: stop once the MD temperature (K) reaches this value.
                None means no limit.
        """
        self._startup()
        while True:
            # Compare against None so that 0 is honored as a real limit
            # instead of silently meaning "run forever".
            if totalsteps is not None and self._counter >= totalsteps:
                self._log('msg', 'Run terminated. Step #%i reached of '
                          '%i allowed. Increase totalsteps if resuming.'
                          % (self._counter, totalsteps))
                return
            if maxtemp is not None and self._temperature >= maxtemp:
                self._log('msg', 'Run terminated. Temperature is %.2f K;'
                          ' max temperature allowed %.2f K.'
                          % (self._temperature, maxtemp))
                return

            self._previous_optimum = self._atoms.copy()
            self._previous_energy = self._atoms.get_potential_energy()
            self._molecular_dynamics()
            self._optimize()
            self._counter += 1
            self._check_results()

    def _startup(self):
        """Initiates a run, and determines if running from previous data or
        a fresh run.

        Rank 0 inspects the files and broadcasts the decision so that every
        rank takes the same branch: 0 = no minima file yet (fresh run),
        1 = minima file exists but no log file (fresh run sharing an
        existing minima list), 2 = both exist (resume in this directory).
        """
        status = np.array(-1.)
        exists = self._read_minima()
        if world.rank == 0:
            if not exists:
                # Fresh run with new minima file.
                status = np.array(0.)
            elif not os.path.exists(self._logfile):
                # Fresh run with existing or shared minima file.
                status = np.array(1.)
            else:
                # Must be resuming from within a working directory.
                status = np.array(2.)
        world.barrier()
        world.broadcast(status, 0)

        if status == 2.:
            self._resume()
        else:
            self._counter = 0
            self._log('init')
            self._log('msg', 'Performing initial optimization.')
            if status == 1.:
                self._log('msg', 'Using existing minima file with %i prior '
                          'minima: %s' % (len(self._minima),
                                          self._minima_traj))
            self._optimize()
            self._check_results()
            self._counter += 1

    def _resume(self):
        """Attempt to resume a run, based on information in the log
        file. Note it will almost always be interrupted in the middle of
        either a qn or md run or when exceeding totalsteps, so it only has
        been tested in those cases currently."""
        with paropen(self._logfile, 'r') as f:
            lines = f.read().splitlines()
        self._log('msg', 'Attempting to resume stopped run.')
        self._log('msg', 'Using existing minima file with %i prior '
                  'minima: %s' % (len(self._minima), self._minima_traj))
        mdcount, qncount = 0, 0
        for line in lines:
            # The last parameter line wins; the header line is recognized
            # (and skipped) by its 'Ediff' column title.
            if line.startswith('par:') and 'Ediff' not in line:
                self._temperature = float(line.split()[1])
                self._Ediff = float(line.split()[2])
            elif line.startswith('msg: Optimization:'):
                qncount = int(line[19:].split('qn')[1])
            elif line.startswith('msg: Molecular dynamics:'):
                mdcount = int(line[25:].split('md')[1])
        self._counter = max((mdcount, qncount))
        if qncount == mdcount:
            # Either stopped during local optimization or terminated due to
            # max steps.
            self._log('msg', 'Attempting to resume at qn%05i' % qncount)
            if qncount > 0:
                atoms = io.read('qn%05i.traj' % (qncount - 1), index=-1)
                self._previous_optimum = atoms.copy()
                self._previous_energy = atoms.get_potential_energy()
            if os.path.getsize('qn%05i.traj' % qncount) > 0:
                atoms = io.read('qn%05i.traj' % qncount, index=-1)
            else:
                atoms = io.read('md%05i.traj' % qncount, index=-3)
            self._atoms.positions = atoms.get_positions()
            fmax = np.sqrt((atoms.get_forces() ** 2).sum(axis=1).max())
            if fmax < self._fmax:
                # Stopped after a qn finished.
                self._log('msg', 'qn%05i fmax already less than fmax=%.3f'
                          % (qncount, self._fmax))
                self._counter += 1
                return
            self._optimize()
            self._counter += 1
            if qncount > 0:
                self._check_results()
            else:
                self._record_minimum()
                self._log('msg', 'Found a new minimum.')
                self._log('msg', 'Accepted new minimum.')
                self._log('par')
        elif qncount < mdcount:
            # Probably stopped during molecular dynamics.
            self._log('msg', 'Attempting to resume at md%05i.' % mdcount)
            atoms = io.read('qn%05i.traj' % qncount, index=-1)
            self._previous_optimum = atoms.copy()
            self._previous_energy = atoms.get_potential_energy()
            self._molecular_dynamics(resume=mdcount)
            self._optimize()
            self._counter += 1
            self._check_results()

    def _check_results(self):
        """Adjusts parameters and positions based on outputs.

        Depending on whether the freshly optimized structure re-found the
        last minimum, matches an earlier minimum, or is new, the MD
        temperature is scaled by beta1, beta2 or beta3. A new minimum is
        accepted (Ediff *= alpha1, recorded) if its energy is below the
        previous energy plus Ediff; otherwise it is rejected (Ediff *=
        alpha2) and the previous positions are restored.
        """
        # No prior minima found?
        self._read_minima()
        if len(self._minima) == 0:
            self._log('msg', 'Found a new minimum.')
            self._log('msg', 'Accepted new minimum.')
            self._record_minimum()
            self._log('par')
            return
        # Returned to starting position?
        if self._previous_optimum is not None:
            compare = ComparePositions(translate=False)
            dmax = compare(self._atoms, self._previous_optimum)
            self._log('msg', 'Max distance to last minimum: %.3f A' % dmax)
            if dmax < self._minima_threshold:
                self._log('msg', 'Re-found last minimum.')
                self._temperature *= self._beta1
                self._log('par')
                return
        # In a previously found position?
        unique, dmax_closest = self._unique_minimum_position()
        self._log('msg', 'Max distance to closest minimum: %.3f A' %
                  dmax_closest)
        if not unique:
            self._temperature *= self._beta2
            self._log('msg', 'Found previously found minimum.')
            self._log('par')
            if self._previous_optimum is not None:
                self._log('msg', 'Restoring last minimum.')
                self._atoms.positions = self._previous_optimum.positions
            return
        # Must have found a unique minimum.
        self._temperature *= self._beta3
        self._log('msg', 'Found a new minimum.')
        self._log('par')
        if (self._previous_energy is None or
                (self._atoms.get_potential_energy() <
                 self._previous_energy + self._Ediff)):
            self._log('msg', 'Accepted new minimum.')
            self._Ediff *= self._alpha1
            self._log('par')
            self._record_minimum()
        else:
            self._log('msg', 'Rejected new minimum due to energy. '
                             'Restoring last minimum.')
            self._atoms.positions = self._previous_optimum.positions
            self._Ediff *= self._alpha2
            self._log('par')

    def _log(self, cat='msg', message=None):
        """Records the message as a line in the log file.

        Parameters:
            cat: 'init' (create the log with column headers; fails if it
                already exists), 'msg' (free text ``message``), 'par'
                (temperature, Ediff, mdmin) or 'ene' (current energy and,
                when available, the previous energy and their difference).
            message: text for the 'msg' category.

        Raises:
            RuntimeError: on 'init' (rank 0) if the log file already exists.
            ValueError: for an unknown category.
        """
        if cat == 'init':
            if world.rank == 0:
                if os.path.exists(self._logfile):
                    raise RuntimeError('File exists: %s' % self._logfile)
            with paropen(self._logfile, 'w') as fd:
                fd.write('par: %12s %12s %12s\n' % ('T (K)', 'Ediff (eV)',
                                                    'mdmin'))
                fd.write('ene: %12s %12s %12s\n' % ('E_current',
                                                    'E_previous',
                                                    'Difference'))
            return
        if cat == 'msg':
            line = 'msg: %s' % message
        elif cat == 'par':
            line = ('par: %12.4f %12.4f %12i' %
                    (self._temperature, self._Ediff, self._mdmin))
        elif cat == 'ene':
            current = self._atoms.get_potential_energy()
            if self._previous_optimum is not None:
                previous = self._previous_energy
                line = ('ene: %12.5f %12.5f %12.5f' %
                        (current, previous, current - previous))
            else:
                line = ('ene: %12.5f' % current)
        else:
            raise ValueError('Unknown log category: %s' % cat)
        with paropen(self._logfile, 'a') as fd:
            fd.write(line + '\n')

    def _optimize(self):
        """Perform an optimization.

        Momenta are zeroed first; the optimizer writes qnNNNNN.traj and
        qnNNNNN.log for the current step counter.
        """
        self._atoms.set_momenta(np.zeros(self._atoms.get_momenta().shape))
        with self._optimizer(self._atoms,
                             trajectory='qn%05i.traj' % self._counter,
                             logfile='qn%05i.log' % self._counter) as opt:
            self._log('msg', 'Optimization: qn%05i' % self._counter)
            opt.run(fmax=self._fmax)
            self._log('ene')

    def _record_minimum(self):
        """Adds the current atoms configuration to the minima list."""
        with io.Trajectory(self._minima_traj, 'a') as traj:
            traj.write(self._atoms)
        self._read_minima()
        self._log('msg', 'Recorded minima #%i.' % (len(self._minima) - 1))

    def _read_minima(self):
        """Reads in the list of minima from the minima file.

        Sets ``self._minima`` (empty if the file is missing or empty) and
        returns whether the file exists.
        """
        exists = os.path.exists(self._minima_traj)
        if exists:
            empty = os.path.getsize(self._minima_traj) == 0
            if not empty:
                with io.Trajectory(self._minima_traj, 'r') as traj:
                    self._minima = [atoms for atoms in traj]
            else:
                self._minima = []
            return True
        else:
            self._minima = []
            return False

    def _molecular_dynamics(self, resume=None):
        """Performs a molecular dynamics simulation, until mdmin is
        exceeded. If resuming, the file number (md%05i) is expected.

        Afterwards the atoms are reset to the positions of the most recently
        passed minimum.
        """
        self._log('msg', 'Molecular dynamics: md%05i' % self._counter)
        mincount = 0
        energies, oldpositions = [], []
        # Absolute index into oldpositions of the most recently passed
        # minimum. Tracked explicitly because, when resuming with mdmin
        # already reached, the loop below never runs and the last
        # passedmin value is not the minimum we want (it may be None).
        last_minimum = None
        thermalized = False
        if resume:
            self._log('msg', 'Resuming MD from md%05i.traj' % resume)
            if os.path.getsize('md%05i.traj' % resume) == 0:
                self._log('msg', 'md%05i.traj is empty. Resuming from '
                          'qn%05i.traj.' % (resume, resume - 1))
                atoms = io.read('qn%05i.traj' % (resume - 1), index=-1)
            else:
                with io.Trajectory('md%05i.traj' % resume, 'r') as images:
                    for atoms in images:
                        energies.append(atoms.get_potential_energy())
                        oldpositions.append(atoms.positions.copy())
                        passedmin = self._passedminimum(energies)
                        if passedmin:
                            mincount += 1
                            last_minimum = len(energies) + passedmin[0]
                self._atoms.set_momenta(atoms.get_momenta())
                thermalized = True
            self._atoms.positions = atoms.get_positions()
            self._log('msg', 'Starting MD with %i existing energies.' %
                      len(energies))
        if not thermalized:
            MaxwellBoltzmannDistribution(self._atoms,
                                         temperature_K=self._temperature,
                                         force_temp=True)
        traj = io.Trajectory('md%05i.traj' % self._counter, 'a',
                             self._atoms)
        dyn = VelocityVerlet(self._atoms, timestep=self._timestep * units.fs)
        log = MDLogger(dyn, self._atoms, 'md%05i.log' % self._counter,
                       header=True, stress=False, peratom=False)

        with traj, dyn, log:
            dyn.attach(log, interval=1)
            dyn.attach(traj, interval=1)
            while mincount < self._mdmin:
                dyn.run(1)
                energies.append(self._atoms.get_potential_energy())
                passedmin = self._passedminimum(energies)
                if passedmin:
                    mincount += 1
                    # passedmin[0] is an offset from the end of energies;
                    # oldpositions gets the matching entry just below.
                    last_minimum = len(energies) + passedmin[0]
                oldpositions.append(self._atoms.positions.copy())
            # Reset atoms to minimum point.
            if last_minimum is not None:
                self._atoms.positions = oldpositions[last_minimum]

    def _unique_minimum_position(self):
        """Identifies if the current position of the atoms, which should be
        a local minima, has been found before.

        Returns ``(unique, dmax_closest)``: whether no stored minimum lies
        within ``minima_threshold``, and the smallest max-displacement to
        any stored minimum (99999. if there are none).
        """
        unique = True
        dmax_closest = 99999.
        compare = ComparePositions(translate=True)
        self._read_minima()
        for minimum in self._minima:
            dmax = compare(minimum, self._atoms)
            if dmax < self._minima_threshold:
                unique = False
            if dmax < dmax_closest:
                dmax_closest = dmax
        return unique, dmax_closest
348
349
class ComparePositions:
    """Class that compares the atomic positions between two ASE atoms
    objects. Returns the maximum distance that any atom has moved, assuming
    all atoms of the same element are indistinguishable. If translate is
    set to True, allows for arbitrary translations within the unit cell,
    as well as translations across any periodic boundary conditions. When
    called, returns the maximum displacement of any one atom."""

    def __init__(self, translate=True):
        self._translate = translate

    def __call__(self, atoms1, atoms2):
        # Work on copies: the comparison strips constraints and translates.
        atoms1 = atoms1.copy()
        atoms2 = atoms2.copy()
        if self._translate:
            return self._translated_compare(atoms1, atoms2)
        return self._indistinguishable_compare(atoms1, atoms2)

    def _translated_compare(self, atoms1, atoms2):
        """Moves the atoms around and tries to pair up atoms, assuming any
        atoms with the same symbol are indistinguishable, and honors
        periodic boundary conditions (for example, so that an atom at
        (0.1, 0., 0.) correctly is found to be close to an atom at
        (7.9, 0., 0.) if the atoms are in an orthorhombic cell with
        x-dimension of 8. Returns dmax, the maximum distance between any
        two atoms in the optimal configuration.

        Raises:
            ValueError: if the two structures have different periodic
                boundary conditions, or atoms2 has no atom of the element
                used for alignment.
        """
        atoms1.set_constraint()
        atoms2.set_constraint()
        # Explicit check instead of assert, which is stripped under -O.
        if not np.array_equal(atoms1.pbc, atoms2.pbc):
            raise ValueError('Periodic boundary conditions differ: %s vs %s'
                             % (atoms1.pbc, atoms2.pbc))
        least = self._get_least_common(atoms1)
        indices1 = [atom.index for atom in atoms1 if atom.symbol == least[0]]
        indices2 = [atom.index for atom in atoms2 if atom.symbol == least[0]]
        if not indices2:
            raise ValueError('No atom with symbol %s in second structure.'
                             % least[0])
        # Make comparison sets from atoms2, which contain repeated atoms in
        # all pbc's and bring the atom listed in indices2 to (0,0,0)
        pbc = np.asarray(atoms2.pbc, dtype=bool)
        repeated = atoms2.repeat([3 if bc else 1 for bc in pbc])
        # Shift back by one cell vector along every periodic direction so
        # the original cell sits in the middle of the 3-fold repetition.
        # Whole cell vectors (rows) are selected; masking Cartesian
        # components instead would be wrong for non-orthogonal cells.
        cell = np.asarray(atoms2.cell)
        repeated.translate(-cell[pbc].sum(axis=0))
        repeated.set_cell(atoms2.cell)
        comparisons = []
        for index in indices2:
            comparison = repeated.copy()
            comparison.translate(-atoms2[index].position)
            comparisons.append(comparison)
        # Bring the atom listed in indices1 to (0,0,0) [not whole list]
        standard = atoms1.copy()
        standard.translate(-atoms1[indices1[0]].position)
        # Compare the standard to the comparison sets.
        return min(self._indistinguishable_compare(standard, comparison)
                   for comparison in comparisons)

    def _get_least_common(self, atoms):
        """Returns the least common element in atoms as ``[symbol, count]``.
        If more than one, returns the first encountered in atom order.
        Returns ``['', np.inf]`` for an empty structure."""
        symbols = [atom.symbol for atom in atoms]
        least = ['', np.inf]
        # dict.fromkeys keeps first-appearance order; iterating a set of
        # strings depends on hash randomization and may differ between
        # processes (e.g. MPI ranks), making ties non-deterministic.
        for element in dict.fromkeys(symbols):
            count = symbols.count(element)
            if count < least[1]:
                least = [element, count]
        return least

    def _indistinguishable_compare(self, atoms1, atoms2):
        """Finds each atom in atoms1's nearest neighbor with the same
        chemical symbol in atoms2. Return dmax, the farthest distance an
        individual atom differs by. Each atom of atoms2 is paired at most
        once (greedily, in atoms1 order).

        Raises:
            ValueError: if an atom of atoms1 has no unpaired partner with
                the same symbol left in atoms2.
        """
        atoms2 = atoms2.copy()  # allow deletion
        atoms2.set_constraint()
        dmax = 0.
        for atom1 in atoms1:
            closest_index, closest_distance = None, np.inf
            for index, atom2 in enumerate(atoms2):
                if atom2.symbol != atom1.symbol:
                    continue
                d = np.linalg.norm(atom1.position - atom2.position)
                if d < closest_distance:
                    closest_index, closest_distance = index, d
            if closest_index is None:
                raise ValueError('No unpaired atom with symbol %s left to '
                                 'compare with.' % atom1.symbol)
            if closest_distance > dmax:
                dmax = closest_distance
            del atoms2[closest_index]
        return dmax
442
443
class PassedMinimum:
    """Detect whether a sequence of energies has just passed a minimum.

    A minimum is recognized when the sequence ends with ``n_down``
    non-increasing steps followed by ``n_up`` non-decreasing steps (two
    of each by default). Calling the instance with the list of energies
    returns a tuple ``(offset, energy)``: the negative index of the
    minimum counted from the end of the list, and the energy at that
    index. Otherwise (including when the list is too short) it returns
    None.
    """

    def __init__(self, n_down=2, n_up=2):
        self._ndown = n_down
        self._nup = n_up

    def __call__(self, energies):
        n_up, n_down = self._nup, self._ndown
        if len(energies) < n_up + n_down + 1:
            return None
        # Offset (from the end) of the candidate minimum.
        pivot = -n_up - 1
        # Written as "not <" / "not >" so NaN comparisons behave exactly
        # like the plain strict-inequality checks they replace.
        rising = all(not energies[k] < energies[k - 1]
                     for k in range(-1, pivot, -1))
        falling = all(not energies[k] > energies[k - 1]
                      for k in range(pivot, pivot - n_down, -1))
        if rising and falling:
            return pivot, energies[pivot]
        return None
472
473
class MHPlot:
    """Makes a plot summarizing the output of the MH algorithm from the
    specified rundirectory. If no rundirectory is supplied, uses the
    current directory."""

    def __init__(self, rundirectory=None, logname='hop.log'):
        if not rundirectory:
            rundirectory = os.getcwd()
        self._rundirectory = rundirectory
        self._logname = logname
        self._read_log()
        self._fig, self._ax = self._makecanvas()
        self._plot_data()

    def get_figure(self):
        """Returns the matplotlib figure object."""
        return self._fig

    def save_figure(self, filename):
        """Saves the file to the specified path, with any allowed
        matplotlib extension (e.g., .pdf, .png, etc.)."""
        self._fig.savefig(filename)

    def _read_log(self):
        """Reads relevant parts of the log file.

        Sets ``self._data`` to a list of ``[energy, status, temperature,
        ediff]`` entries, one per completed step. If the log ends in a
        different status than the last completed step, one more entry with
        a NaN energy is appended for the step still in progress. A log
        without any completed step does not raise.
        """
        data = []  # format: [energy, status, temperature, ediff]

        with open(os.path.join(self._rundirectory, self._logname), 'r') as fd:
            lines = fd.read().splitlines()

        # Defined starting values, so that a log read before the first step
        # has finished (or that lacks some line types) cannot hit a
        # NameError or index an empty ``data`` below.
        energy = temperature = ediff = np.nan
        status = None
        step_almost_over = False
        for line in lines:
            step_over = False
            if line.startswith('msg: Molecular dynamics:'):
                status = 'performing MD'
            elif line.startswith('msg: Optimization:'):
                status = 'performing QN'
            elif line.startswith('ene:'):
                status = 'local optimum reached'
                energy = floatornan(line.split()[1])
            elif line.startswith('msg: Accepted new minimum.'):
                status = 'accepted'
                step_almost_over = True
            elif line.startswith('msg: Found previously found minimum.'):
                status = 'previously found minimum'
                step_almost_over = True
            elif line.startswith('msg: Re-found last minimum.'):
                status = 'previous minimum'
                step_almost_over = True
            elif line.startswith('msg: Rejected new minimum'):
                status = 'rejected'
                step_almost_over = True
            elif line.startswith('par: '):
                temperature = floatornan(line.split()[1])
                ediff = floatornan(line.split()[2])
                # A parameter line following an outcome closes the step.
                if step_almost_over:
                    step_over = True
                    step_almost_over = False
            if step_over:
                data.append([energy, status, temperature, ediff])
        # Record the step still in progress (its energy is not known yet).
        if status is not None and (not data or data[-1][1] != status):
            data.append([np.nan, status, temperature, ediff])
        self._data = data

    def _makecanvas(self):
        """Creates the figure with the split energy axes and the
        temperature/Ediff axes; returns ``(figure, CombinedAxis)``."""
        from matplotlib import pyplot
        from matplotlib.ticker import ScalarFormatter
        fig = pyplot.figure(figsize=(6., 8.))
        lm, rm, bm, tm = 0.22, 0.02, 0.05, 0.04
        vg1 = 0.01  # between adjacent energy plots
        vg2 = 0.03  # between different types of plots
        ratio = 2.  # size of an energy plot to a parameter plot
        figwidth = 1. - lm - rm
        totalfigheight = 1. - bm - tm - vg1 - 2. * vg2
        parfigheight = totalfigheight / (2. * ratio + 2)
        epotheight = ratio * parfigheight
        ax1 = fig.add_axes((lm, bm, figwidth, epotheight))
        ax2 = fig.add_axes((lm, bm + epotheight + vg1,
                            figwidth, epotheight))
        for ax in [ax1, ax2]:
            ax.yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
        ediffax = fig.add_axes((lm, bm + 2. * epotheight + vg1 + vg2,
                                figwidth, parfigheight))
        tempax = fig.add_axes((lm, (bm + 2 * epotheight + vg1 + 2 * vg2 +
                                    parfigheight), figwidth, parfigheight))
        for ax in [ax2, tempax, ediffax]:
            ax.set_xticklabels([])
        ax1.set_xlabel('step')
        tempax.set_ylabel('$T$, K')
        ediffax.set_ylabel(r'$E_\mathrm{diff}$, eV')
        for ax in [ax1, ax2]:
            ax.set_ylabel(r'$E_\mathrm{pot}$, eV')
        ax = CombinedAxis(ax1, ax2, tempax, ediffax)
        self._set_zoomed_range(ax)
        ax1.spines['top'].set_visible(False)
        ax2.spines['bottom'].set_visible(False)
        return fig, ax

    def _set_zoomed_range(self, ax):
        """Try to intelligently set the range for the zoomed-in part of the
        graph: the span of the step energies padded by 20% on each side."""
        energies = [line[0] for line in self._data
                    if not np.isnan(line[0])]
        if not energies:
            # No completed step yet: there is nothing to zoom on, but
            # CombinedAxis still needs a lower-axis range to work with.
            ax.set_ax1_range((0., 1.))
            return
        dr = max(energies) - min(energies)
        if dr == 0.:
            dr = 1.
        ax.set_ax1_range((min(energies) - 0.2 * dr,
                          max(energies) + 0.2 * dr))

    def _plot_data(self):
        """Draws every step, then the parameter traces."""
        for step, line in enumerate(self._data):
            self._plot_energy(step, line)
            self._plot_qn(step, line)
            self._plot_md(step, line)
        self._plot_parameters()
        self._ax.set_xlim(self._ax.ax1.get_xlim())

    def _plot_energy(self, step, line):
        """Plots energy and annotation for acceptance."""
        energy, status = line[0], line[1]
        if np.isnan(energy):
            return
        self._ax.plot([step, step + 0.5], [energy] * 2, '-',
                      color='k', linewidth=2.)
        if status == 'accepted':
            self._ax.text(step + 0.51, energy, r'$\checkmark$')
        elif status == 'rejected':
            self._ax.text(step + 0.51, energy, r'$\Uparrow$', color='red')
        elif status == 'previously found minimum':
            self._ax.text(step + 0.51, energy, r'$\hookleftarrow$',
                          color='red', va='center')
        elif status == 'previous minimum':
            self._ax.text(step + 0.51, energy, r'$\leftarrow$',
                          color='red', va='center')

    def _plot_md(self, step, line):
        """Adds a curved plot of molecular dynamics trajectory."""
        if step == 0:
            return
        energies = [self._data[step - 1][0]]
        file = os.path.join(self._rundirectory, 'md%05i.traj' % step)
        # The trajectory may not exist yet, or still be empty, for a step
        # that is in progress (same guard as in _plot_qn).
        if os.path.exists(file) and os.path.getsize(file) > 0:
            with io.Trajectory(file, 'r') as traj:
                for atoms in traj:
                    energies.append(atoms.get_potential_energy())
        xi = step - 1 + .5
        if len(energies) > 2:
            xf = xi + (step + 0.25 - xi) * len(energies) / (len(energies) - 2.)
        else:
            xf = step
        if xf > (step + .75):
            xf = step
        self._ax.plot(np.linspace(xi, xf, num=len(energies)), energies,
                      '-k')

    def _plot_qn(self, index, line):
        """Plots a dashed vertical line for the optimization."""
        if line[1] == 'performing MD':
            return
        file = os.path.join(self._rundirectory, 'qn%05i.traj' % index)
        if os.path.getsize(file) == 0:
            return
        with io.Trajectory(file, 'r') as traj:
            energies = [traj[0].get_potential_energy(),
                        traj[-1].get_potential_energy()]
        if index > 0:
            file = os.path.join(self._rundirectory, 'md%05i.traj' % index)
            atoms = io.read(file, index=-3)
            energies[0] = atoms.get_potential_energy()
        self._ax.plot([index + 0.25] * 2, energies, ':k')

    def _plot_parameters(self):
        """Adds a plot of temperature and Ediff to the plot."""
        steps, Ts, ediffs = [], [], []
        for step, line in enumerate(self._data):
            steps.extend([step + 0.5, step + 1.5])
            Ts.extend([line[2]] * 2)
            ediffs.extend([line[3]] * 2)
        self._ax.tempax.plot(steps, Ts)
        self._ax.ediffax.plot(steps, ediffs)

        for ax in [self._ax.tempax, self._ax.ediffax]:
            ylim = ax.get_ylim()
            yrange = ylim[1] - ylim[0]
            ax.set_ylim((ylim[0] - 0.1 * yrange, ylim[1] + 0.1 * yrange))
659
660
def floatornan(value):
    """Converts the argument into a float if possible, np.nan if not.

    Both unparsable strings (ValueError) and values of unsupported type,
    such as None (TypeError), map to ``np.nan``.
    """
    try:
        return float(value)
    except (ValueError, TypeError):
        return np.nan
668
669
class CombinedAxis:
    """Helper class for MHPlot to plot on split y axis and adjust limits
    simultaneously.

    ``ax1`` shows the zoomed-in energy range given to ``set_ax1_range``.
    ``ax2`` sits above it and spans from the top of that range to the
    highest y value plotted so far. ``tempax`` and ``ediffax`` only share
    the x limits.
    """

    def __init__(self, ax1, ax2, tempax, ediffax):
        self.ax1 = ax1
        self.ax2 = ax2
        self.tempax = tempax
        self.ediffax = ediffax
        self._ymax = -np.inf

    def set_ax1_range(self, ylim):
        """Sets (and remembers) the y limits of the zoomed-in lower axis.
        Must be called before ``plot`` or ``text``."""
        self._ax1_ylim = ylim
        self.ax1.set_ylim(ylim)

    def plot(self, *args, **kwargs):
        """Plots on both energy axes; ``args[1]`` must be the y values."""
        self.ax1.plot(*args, **kwargs)
        self.ax2.plot(*args, **kwargs)
        # Re-adjust yrange (NaN never compares greater, so it is ignored).
        for yvalue in args[1]:
            if yvalue > self._ymax:
                self._ymax = yvalue
        self.ax1.set_ylim(self._ax1_ylim)
        top = self._ax1_ylim[1]
        # Only set the upper axis once something lies above the zoomed
        # range; otherwise its limits would be inverted (or infinite).
        if self._ymax > top:
            self.ax2.set_ylim((top, self._ymax))

    def set_xlim(self, *args):
        """Applies the same x limits to all four axes."""
        self.ax1.set_xlim(*args)
        self.ax2.set_xlim(*args)
        self.tempax.set_xlim(*args)
        self.ediffax.set_xlim(*args)

    def text(self, *args, **kwargs):
        """Places text on whichever energy axis contains y = ``args[1]``."""
        y = args[1]
        if y < self._ax1_ylim[1]:
            ax = self.ax1
        else:
            ax = self.ax2
        ax.text(*args, **kwargs)
708