1import multiprocessing as mp
2import os
3import re
4import sys
5import time
6from collections import defaultdict
7
8import numpy as np
9from scipy.optimize import differential_evolution as DE
10from ase import Atoms
11from ase.data import covalent_radii, atomic_numbers
12from ase.units import Bohr, Ha
13
14from gpaw import GPAW, PW, setup_paths, Mixer, ConvergenceError
15from gpaw.atom.generator2 import generate  # , DatasetGenerationError
16from gpaw.atom.aeatom import AllElectronAtom
17from gpaw.atom.atompaw import AtomPAW
18from gpaw.setup import create_setup
19
20
# Private copy of ASE's covalent-radius table in which the heaviest
# actinides, which have no usable radius there, get a fixed value of 1.7.
my_covalent_radii = covalent_radii.copy()
for _symbol in ('Bk', 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr'):
    my_covalent_radii[atomic_numbers[_symbol]] = 1.7
24
# Reference radius per element, keyed by atomic number.  DatasetOptimizer
# divides the value by Bohr and uses the result as the reference cutoff
# radius ``rc`` (presumably the values are in Angstrom -- TODO confirm).
# Elements without an entry cannot be optimized.
my_radii = {
    # Z = 1-2
    1: 0.41, 2: 1.46,
    # Z = 3-10
    3: 1.47, 4: 1.08, 5: 0.83, 6: 0.75, 7: 0.67, 8: 0.68, 9: 0.7, 10: 1.63,
    # Z = 11-17 (no entry for 18)
    11: 1.7, 12: 1.47, 13: 1.4, 14: 1.18, 15: 1.07, 16: 1.15, 17: 1.06,
    # Z = 19-35 (no entry for 36)
    19: 2.09, 20: 1.71, 21: 1.58, 22: 1.44, 23: 1.32, 24: 1.28, 25: 1.31,
    26: 1.28, 27: 1.28, 28: 1.28, 29: 1.29, 30: 1.28, 31: 1.26, 32: 1.28,
    33: 1.26, 34: 1.18, 35: 1.19,
    # Z = 37-53 (no entries for 54 and 55)
    37: 2.24, 38: 1.94, 39: 1.73, 40: 1.6, 41: 1.45, 42: 1.39, 43: 1.31,
    44: 1.32, 45: 1.34, 46: 1.39, 47: 1.52, 48: 1.53, 49: 1.57, 50: 1.45,
    51: 1.43, 52: 1.45, 53: 1.42,
    # Z = 56-84 (no entries for 85-87)
    56: 1.95, 57: 1.89, 58: 1.84, 59: 1.91, 60: 1.88, 61: 1.61, 62: 1.85,
    63: 1.83, 64: 1.82, 65: 1.76, 66: 1.74, 67: 1.73, 68: 1.72, 69: 1.71,
    70: 1.7, 71: 1.69, 72: 1.57, 73: 1.45, 74: 1.39, 75: 1.34, 76: 1.35,
    77: 1.35, 78: 1.39, 79: 1.38, 80: 1.7, 81: 2.27, 82: 1.81, 83: 1.59,
    84: 1.69,
    # Z = 88-90
    88: 2.06, 89: 1.95, 90: 1.92,
}
38
39
class PAWDataError(Exception):
    """Raised when a PAW dataset cannot be generated or fails a check."""
42
43
class DatasetOptimizer:
    """Search for good PAW-dataset parameters for a single element.

    The free parameters are the projector energies embedded in the
    projector string, the cutoff radii and ``r0``.  Their starting values
    are read from ``start.txt`` in the current working directory.

    An instance is callable: ``optimizer(x)`` generates datasets for the
    parameter vector ``x``, runs the test calculations and returns one
    scalar error.  :meth:`run` hands the instance to
    ``scipy.optimize.differential_evolution`` as the objective function.
    Every evaluation is appended as one comma-separated row to
    ``data-<id>.csv``; :meth:`read`, :meth:`summary` and :meth:`best`
    work on those files.
    """

    # Scale of each of the six entries of the ``errors`` list returned by
    # test(); __call__() divides by these before squaring and summing.
    tolerances = np.array([0.2,  # radii
                           0.3,  # log. derivs.
                           40,  # iterations
                           1.2 * 2 / 3 * 300 * 0.1**0.25,  # convergence
                           0.0005,  # eggbox error
                           0.05])  # IP

    def __init__(self, symbol='H', nc=False):
        """Read the starting point for ``symbol`` from ``start.txt``.

        ``start.txt`` must hold whitespace-separated words in the layout
        printed by :meth:`best`: word 1 is the chemical symbol, word 3 the
        projector string, word 5 the comma-separated radii and word 7
        ``<nderiv0>,<r0>``.

        Parameters
        ----------
        symbol:
            Chemical symbol; must match the symbol found in ``start.txt``.
        nc:
            If true, :meth:`generate` asks for ``'nc'`` instead of
            ``'poly'`` pseudization.
        """
        # The module only imports pathlib in its ``__main__`` section, so
        # import it here to keep the class usable when imported.
        from pathlib import Path

        self.old = False

        self.symbol = symbol
        self.nc = nc
        # Name of the setups used by the test calculations.  __call__()
        # replaces it with a name that is unique to the worker process.
        self.setup = 'de'

        words = Path('start.txt').read_text().split()
        assert words[1] == symbol, \
            f'start.txt is for {words[1]!r}, not {symbol!r}'
        projectors = words[3]
        radii = [float(f) for f in words[5].split(',')]
        r0 = float(words[7].split(',')[1])

        self.Z = atomic_numbers[symbol]
        rc = self.rc = my_radii[self.Z] / Bohr

        # Every number with one decimal in the projector string is a free
        # energy; replace each by a format field so that parameters() can
        # fill in new values.
        pattern = r'(-?\d+\.\d)'
        energies = [float(m.group())
                    for m in re.finditer(pattern, projectors)]
        self.projectors = re.sub(pattern, '{:.1f}', projectors)
        self.nenergies = len(energies)

        # Parameter vector layout: energies, radii, r0.
        self.x = energies + radii + [r0]
        self.bounds = ([(-1.0, 4.0)] * self.nenergies +
                       [(rc * 0.7, rc * 1.0)] * len(radii) +
                       [(0.3, rc)])

        self.ecut1 = 450.0
        self.ecut2 = 800.0

        # Make GPAW look for setup files in the working directory first.
        setup_paths[:0] = ['.']

        # Kept for backward compatibility; __call__() now opens and closes
        # the log file on every evaluation and no longer uses these.
        self.logfile = None
        self.tflush = time.time() + 60

    def run(self):
        """Minimize the error with differential evolution on 8 workers."""
        # NOTE(review): self.rc was already divided by Bohr in __init__;
        # confirm that dividing once more is the intended unit here.
        print(self.symbol, self.rc / Bohr, self.projectors)
        print(self.x)
        print(self.bounds)
        init = 'latinhypercube'
        # Disabled: seed the population with previously evaluated points.
        if 0:  # os.path.isfile('data.csv'):
            n = len(self.x)
            data = self.read()[:15 * n]
            if np.isfinite(data[:, n]).all() and len(data) == 15 * n:
                init = data[:, :n]

        DE(self, self.bounds, init=init, workers=8, updating='deferred')

    @staticmethod
    def _nderiv0(projectors):
        """Return 5 if the projector string ends in upper case, else 2."""
        return 5 if projectors[-1].isupper() else 2

    def generate(self, fd, projectors, radii, r0, xc,
                 scalar_relativistic=True, tag=None, logderivs=True):
        """Generate one dataset and measure its log-derivative error.

        Parameters
        ----------
        fd:
            Passed on as ``output`` to the dataset generator (may be None).
        projectors, radii, r0:
            Dataset parameters as returned by :meth:`parameters`.
        xc:
            Name of the XC functional.
        scalar_relativistic:
            Passed on to the generator.  ``gen.check_all()`` is only
            consulted when this is false.
        tag:
            If not None, the dataset is written to an XML file with this
            tag (an empty tag is passed on as None).
        logderivs:
            If false, the log-derivative comparison is skipped.

        Returns
        -------
        float
            Sum over ``l = 0..lmax`` (``lmax`` is 3 if the projector string
            contains ``'f'``, else 2) of the integrated absolute difference
            between the all-electron and the dataset logarithmic
            derivatives, evaluated at ``1.1 * gen.rcmax``.  0.0 when
            ``logderivs`` is false.

        Raises
        ------
        PAWDataError
            If the generator raises ``LinAlgError`` or the dataset check
            fails.
        """
        pseudization = 'nc' if self.nc else 'poly'

        try:
            gen = generate(self.symbol, projectors, radii, r0,
                           self._nderiv0(projectors),
                           xc, scalar_relativistic, (pseudization, 4),
                           output=fd)
        except np.linalg.LinAlgError as err:
            raise PAWDataError('LinAlgError') from err

        if not scalar_relativistic and not gen.check_all():
            raise PAWDataError('dataset check failed')

        if tag is not None:
            gen.make_paw_setup(tag or None).write_xml()

        if not logderivs:
            return 0.0

        r = 1.1 * gen.rcmax
        lmax = 3 if 'f' in projectors else 2

        error = 0.0
        for l in range(lmax + 1):
            emin = -1.5
            emax = 2.0
            n0 = gen.number_of_core_states(l)
            if n0 > 0:
                # Start the energy window 0.1 above the highest core state.
                e0_n = gen.aea.channels[l].e_n
                emin = max(emin, e0_n[n0 - 1] + 0.1)
            energies = np.linspace(emin, emax, 100)
            de = energies[1] - energies[0]
            ld1 = gen.aea.logarithmic_derivative(l, energies, r)
            ld2 = gen.logarithmic_derivative(l, energies, r)
            error += abs(ld1 - ld2).sum() * de

        return error

    def parameters(self, x):
        """Split a parameter vector into its parts.

        Returns ``(energies, radii, r0, projectors)`` where ``projectors``
        is the projector string with the energies filled in.
        """
        energies = x[:self.nenergies]
        radii = x[self.nenergies:-1]
        r0 = x[-1]
        projectors = self.projectors.format(*energies)
        return energies, radii, r0, projectors

    def __call__(self, x):
        """Evaluate the parameter vector ``x`` and return its error.

        The error is the sum of the squared entries of the ``errors`` list
        from :meth:`test`, each divided by its tolerance.  Output from the
        calculations goes to ``out-<id>.txt`` (overwritten) and one row
        ``x, error, errors, convergence energies, eggbox energies, IPs``
        is appended to ``data-<id>.csv``, where ``<id>`` is the last
        character of the current process name.
        """
        worker = mp.current_process().name[-1]
        self.setup = 'de' + worker
        energies, radii, r0, projectors = self.parameters(x)

        with open(f'out-{worker}.txt', 'w') as fd:
            errors, msg, convenergies, eggenergies, ips = \
                self.test(fd, projectors, radii, r0)
        error = ((np.asarray(errors) / self.tolerances)**2).sum()

        if msg:
            print(msg, x, error, errors, convenergies, eggenergies, ips,
                  file=sys.stderr)

        # Always write seven convergence columns so all rows line up.
        convenergies += [0] * (7 - len(convenergies))

        row = list(x) + [error] + errors + convenergies + eggenergies + ips
        with open(f'data-{worker}.csv', 'a') as logfile:
            # float() keeps the text parseable by np.loadtxt: NumPy >= 2
            # would otherwise write scalars as 'np.float64(...)'.
            print(', '.join(repr(float(number)) for number in row),
                  file=logfile)

        return error

    def test_old_paw_data(self):
        """Run the convergence and eggbox tests with the ``'paw'`` setups.

        Output and a final ``RESULTS:`` line are written to ``old.txt``.
        """
        with open('old.txt', 'w') as fd:
            area, niter, convenergies = self.convergence(fd, 'paw')
            eggenergies = self.eggbox(fd, 'paw')
            numbers = ([area, niter, max(eggenergies)] +
                       convenergies + eggenergies)
            print('RESULTS:',
                  ', '.join(repr(float(number)) for number in numbers),
                  file=fd)

    def test(self, fd, projectors, radii, r0):
        """Generate datasets and run all tests for one parameter set.

        Returns ``(errors, msg, energies, eggenergies, [ip, ip0])``.
        ``errors`` has six entries in the order of ``tolerances``; entries
        that were not reached because a step failed stay at ``inf``, and
        ``msg`` then holds the text of the exception.
        """
        errors = [np.inf] * 6
        energies = []
        eggenergies = [0, 0, 0]
        ip = 0.0
        ip0 = 0.0
        msg = ''

        try:
            if any(r < r0 for r in radii):
                raise PAWDataError('Core radius too large')

            # Penalize radii that exceed the reference radius.
            rc = self.rc
            errors[0] = sum(r - rc for r in radii if r > rc)

            # Only the tagged variants are written to disk; the others
            # just contribute to the log-derivative error.
            error = 0.0
            for kwargs in [dict(xc='PBE', tag=self.setup),
                           dict(xc='PBE', scalar_relativistic=False),
                           dict(xc='LDA', tag=self.setup),
                           dict(xc='PBEsol'),
                           dict(xc='RPBE'),
                           dict(xc='PW91')]:
                error += self.generate(fd, projectors, radii, r0, **kwargs)
            errors[1] = error

            area, niter, energies = self.convergence(fd, self.setup)
            errors[2] = niter
            errors[3] = area

            eggenergies = self.eggbox(fd, self.setup)
            errors[4] = max(eggenergies)

            ip, ip0 = self.ip(fd)
            errors[5] = ip - ip0

        except (ConvergenceError, PAWDataError, RuntimeError,
                np.linalg.LinAlgError) as e:
            msg = str(e)

        return errors, msg, energies, eggenergies, [ip, ip0]

    def _scf_kwargs(self):
        """Return the ``maxiter``/``mixer`` keywords for a GPAW calculator.

        Elements with Z in 58-70 or 90-102 get 999 iterations and a
        ``Mixer(0.01, 5)``; all others get 333 iterations and GPAW's
        default mixer.
        """
        if 58 <= self.Z <= 70 or 90 <= self.Z <= 102:
            return {'maxiter': 999, 'mixer': Mixer(0.01, 5)}
        return {'maxiter': 333}

    def eggbox(self, fd, setup='de'):
        """Measure the eggbox error for grid spacings 0.16, 0.18 and 0.2.

        For each spacing ``h`` the atom is placed at four positions that
        are ``h / 6`` apart and the spread (max - min) of the four total
        energies is recorded.

        Parameters
        ----------
        fd:
            Text output of the calculators.
        setup:
            Name of the setups to use.

        Returns
        -------
        list
            One energy spread per grid spacing.
        """
        energies = []
        for h in [0.16, 0.18, 0.2]:
            a0 = 16 * h
            atoms = Atoms(self.symbol, cell=(a0, a0, 2 * a0), pbc=True)
            atoms.calc = GPAW(h=h,
                              xc='PBE',
                              symmetry='off',
                              setups=setup,
                              txt=fd,
                              **self._scf_kwargs())
            atoms.positions += h / 2  # start with broken symmetry
            samples = [atoms.get_potential_energy()]
            for _ in range(3):
                atoms.positions -= h / 6
                samples.append(atoms.get_potential_energy())
            energies.append(np.ptp(samples))
        return energies

    def convergence(self, fd, setup='de'):
        """Measure the plane-wave convergence of the total energy.

        A reference energy is computed with ``PW(1500)``; the cutoff is
        then lowered from 800 to 300 in steps of 100 and the difference
        ``de`` to the reference is recorded.  ``area`` is the trapezoidal
        integral of ``|de|**0.25`` from cutoff 800 down to the
        (interpolated) cutoff where ``|de|`` exceeds 0.1.

        Parameters
        ----------
        fd:
            Text output of the calculator.
        setup:
            Name of the setups to use.

        Returns
        -------
        tuple
            ``(area, iters, energies)``: ``area`` is ``inf`` if the error
            already exceeds 0.1 at cutoff 800, ``iters`` is the number of
            iterations of the reference calculation and ``energies`` is
            the reference energy followed by the recorded differences.
        """
        a = 3.0
        atoms = Atoms(self.symbol, cell=(a, a, a), pbc=True)
        atoms.calc = GPAW(mode=PW(1500),
                          xc='PBE',
                          setups=setup,
                          symmetry='off',
                          txt=fd,
                          **self._scf_kwargs())
        e0 = atoms.get_potential_energy()
        energies = [e0]
        iters = atoms.calc.get_number_of_iterations()
        oldfde = None
        area = 0.0

        def f(x):
            return x**0.25

        for ec in range(800, 200, -100):
            atoms.calc.set(mode=PW(ec))
            atoms.calc.set(eigensolver='rmm-diis')
            de = atoms.get_potential_energy() - e0
            energies.append(de)
            fde = f(abs(de))
            if fde > f(0.1):
                if oldfde is None:
                    return np.inf, iters, energies
                # Interpolate the cutoff where the threshold is crossed
                # and add the last, partial trapezoid.
                ec0 = ec + (fde - f(0.1)) / (fde - oldfde) * 100
                area += ((ec + 100) - ec0) * (f(0.1) + oldfde) / 2
                break

            if oldfde is not None:
                area += 100 * (fde + oldfde) / 2
            oldfde = fde

        return area, iters, energies

    def ip(self, fd):
        """Return the two ionization energies from the module-level ip()."""
        IP, IP0 = ip(self.symbol, fd, self.setup)
        return IP, IP0

    def read(self):
        """Load all logged evaluations, sorted by increasing error.

        Rows are collected from ``data-1.csv`` ... ``data-8.csv`` and
        ``data-s.csv``; files that do not exist are skipped.  If
        ``data-1.csv`` is missing, ``data.csv`` is used instead of the
        per-process files.

        Returns
        -------
        numpy.ndarray
            2-d array with one row per evaluation, sorted by the error
            column (index ``len(self.x)``).  Has zero rows if no data was
            found.
        """
        rows = []
        for i in '12345678s':
            try:
                # ndmin=2 keeps a file with a single row two-dimensional.
                d = np.loadtxt(f'data-{i}.csv', delimiter=',', ndmin=2)
            except OSError:
                if i == '1':
                    rows = np.loadtxt('data.csv', delimiter=',',
                                      ndmin=2).tolist()
                    break
                continue
            rows += d.tolist()
        if not rows:
            return np.empty((0, len(self.x) + 1))
        data = np.array(rows)
        return data[data[:, len(self.x)].argsort()]

    def summary(self, N=10):
        """Print one line for each of the ``N`` best evaluations.

        Each line shows Z, symbol, total error, the six errors divided by
        their tolerances (tagged ``r l c i e x``), the projector energies
        and the radii (including ``r0``).
        """
        n = len(self.x)
        for row in self.read()[:N]:
            scaled = ', '.join(
                f'{e:4.1f}{s}'
                for e, s in zip(row[n + 1:n + 7] / self.tolerances,
                                'rlciex'))
            energies = ', '.join(f'{e:+.2f}'
                                 for e in row[:self.nenergies])
            radii = ', '.join(f'{r:.2f}' for r in row[self.nenergies:n])
            print(f'{self.Z:3} {self.symbol:2} {row[n]:9.1f} '
                  f'({scaled}) ({energies}, {radii})')

    def best(self):
        """Print the best parameter set and write its dataset.

        The printed line has the layout expected in ``start.txt``.  If the
        error of the best evaluation is finite, a scalar-relativistic PBE
        dataset is generated and written with the tag ``'a7o'``.
        """
        n = len(self.x)
        a = self.read()[0]
        x = a[:n]
        error = a[n]
        energies, radii, r0, projectors = self.parameters(x)

        fmt = '{:3} {:2} -P {:31} -r {:20} -0 {},{:.2f} # {:10.3f}'
        print(fmt.format(self.Z,
                         self.symbol,
                         projectors,
                         ','.join('{0:.2f}'.format(r) for r in radii),
                         self._nderiv0(projectors),
                         r0,
                         error))

        # ``error != np.nan`` is always true, so test finiteness instead.
        if np.isfinite(error):
            self.generate(None, projectors, radii, r0, 'PBE', True, 'a7o',
                          logderivs=False)
370
def ip(symbol, fd, setup):
    """Compute the ionization energy of an atom in two different ways.

    First with two all-electron calculations (neutral atom and atom with
    one electron removed from the highest occupied state), then with two
    ``AtomPAW`` calculations that use the setup ``setup`` and LDA.

    Returns ``(IP, IP2)``: the all-electron energy difference scaled by
    ``Ha`` and the ``AtomPAW`` energy difference.
    """
    xc = 'LDA'

    def solve(hole=None):
        # Solve the all-electron atom, optionally with one electron
        # removed from the (n, l) state given by ``hole``; the final
        # refinement is done scalar-relativistically.
        atom = AllElectronAtom(symbol, log=fd)
        if hole is not None:
            n, l = hole
            atom.add(n, l, -1)
        atom.initialize()
        atom.run()
        atom.refine()
        atom.scalar_relativistic = True
        atom.refine()
        return atom

    def total_energy(atom):
        return atom.ekin + atom.eH + atom.eZ + atom.exc

    neutral = solve()
    energy = total_energy(neutral)

    # Collect (eigenvalue, n, l) for the occupied states of each channel;
    # the states of channel l are numbered n = l + 1, l + 2, ...
    occupied = []
    for l, channel in enumerate(neutral.channels):
        for k, (eig, occ) in enumerate(zip(channel.e_n, channel.f_n)):
            if occ == 0:
                break
            occupied.append((eig, l + 1 + k, l))
    _, n0, l0 = max(occupied)

    ion = solve(hole=(n0, l0))
    IP = total_energy(ion) - energy
    IP *= Ha

    # Occupation numbers of the setup, grouped by angular momentum.
    s = create_setup(symbol, type=setup, xc=xc)
    f_ln = defaultdict(list)
    for l, occ in zip(s.l_j, s.f_j):
        if occ:
            f_ln[l].append(occ)
    f_sln = [[f_ln[l] for l in range(1 + max(f_ln))]]

    calc = AtomPAW(symbol, f_sln, xc=xc, txt=fd, setups=setup)
    energy = calc.results['energy']

    # Remove one electron from the last state of channel l0.
    f_sln[0][l0][-1] -= 1
    calc = AtomPAW(symbol, f_sln, xc=xc, charge=1, txt=fd, setups=setup)
    IP2 = calc.results['energy'] - energy
    return IP, IP2
414
415
if __name__ == '__main__':
    # Command-line driver: each folder must be named after a chemical
    # symbol and contain a ``start.txt`` file.  Without options the
    # optimization is run in every folder.
    import argparse
    from pathlib import Path
    parser = argparse.ArgumentParser(usage='python -m gpaw.atom.optimize '
                                     '[options] [folder folder ...]',
                                     description='Optimize PAW data')
    parser.add_argument('-s', '--summary', type=int)
    parser.add_argument('-b', '--best', action='store_true')
    parser.add_argument('-n', '--norm-conserving', action='store_true')
    parser.add_argument('-o', '--old-setups', action='store_true')
    parser.add_argument('folder', nargs='*')
    args = parser.parse_args()
    folders = [Path(folder) for folder in args.folder or ['.']]
    home = Path.cwd()
    for folder in folders:
        try:
            os.chdir(folder)
            # The element is taken from the name of the folder.
            symbol = Path.cwd().name
            # Pass the -n flag on; it used to be parsed but ignored.
            do = DatasetOptimizer(symbol, nc=args.norm_conserving)
            if args.summary:
                do.summary(args.summary)
            elif args.old_setups:
                do.test_old_paw_data()
            elif args.best:
                do.best()
            else:
                do.run()
        finally:
            # Return to the starting directory so that relative folder
            # names keep working for the next iteration.
            os.chdir(home)
445