1import numpy as np
2import pytest
3
4from ase.lattice.hexagonal import Graphene
5from ase.parallel import parprint as pp
6
7from gpaw import GPAW
8from gpaw.response.df import DielectricFunction
9from gpaw.mpi import world
10
11
@pytest.mark.skip(reason='TODO')
def test_graphene_EELS():
    """Check RPA EELS peaks of a graphene sheet against reference values.

    A plane-wave ground state is computed on rank 0 only, using a private
    single-rank communicator, and written to a .gpw file.  All ranks then
    restart from that file and compute an EELS spectrum at
    q = (1/nkpts, 0, 0) with ``DielectricFunction``.  For each spectrum
    column, the position and height of the maximum are compared with
    hard-coded reference values.
    """
    system = Graphene(symbol='C',
                      latticeconstant={'a': 2.45, 'c': 1.0},
                      size=(1, 1, 1))
    # Periodic in-plane only; vacuum padding along the non-periodic z axis.
    system.pbc = (1, 1, 0)
    system.center(axis=2, vacuum=4.0)

    nkpts = 5

    # Each rank gets a communicator containing only itself, so the
    # ground-state calculation below runs serially on rank 0.
    communicator = world.new_communicator(np.array([world.rank]))
    gpwname = 'dump.graphene.gpw'

    if world.rank == 0:
        calc = GPAW(mode='pw',
                    kpts=(nkpts, nkpts, 1),
                    communicator=communicator,
                    xc='oldLDA',
                    nbands=len(system) * 6,
                    txt='gpaw.graphene.txt')
        system.calc = calc
        system.get_potential_energy()
        calc.write(gpwname, mode='all')

    # No rank may restart from the .gpw file before rank 0 has written it.
    world.barrier()

    parallel = dict(domain=(1, 1, 1), band=1)
    if world.size == 8:
        # parallel['domain'] = (1, 1, 2)
        parallel['band'] = 2
    calc = GPAW(gpwname,
                txt=None,
                parallel=parallel,
                idiotproof=False)
    pp('after restart')

    q = np.array([1.0 / nkpts, 0., 0.])
    # 320 points from 0 to 31.9, i.e. a uniform spacing dw of 0.1.
    w = np.linspace(0, 31.9, 320)
    dw = w[1] - w[0]

    def getpeak(energies, loss):
        """Return (energy, loss) at the position of the maximum of ``loss``."""
        arg = loss.argmax()
        energy = energies[arg]
        peakloss = loss[arg]
        return energy, peakloss

    # Source lines for check_df() calls rebuilt from the computed peaks;
    # printed at the end so reference values can be updated by copy-paste.
    scriptlines = []

    # Absolute deviations from the reference values, asserted at the end.
    loss_errs = []
    energy_errs = []

    def check(name, energy, peakloss, ref_energy, ref_loss):
        """Print a computed peak next to its reference; record deviations."""
        pp('check %s :: energy = %5.2f [%5.2f], peakloss = %.12f [%.12f]'
           % (name, energy, ref_energy, peakloss, ref_loss))
        energy_errs.append(abs(energy - ref_energy))
        loss_errs.append(abs(peakloss - ref_loss))

    # Template for one check_df() call.  Its leading indentation matches
    # this function body, so printed lines can be pasted back in directly.
    template = """\
    check_df('%s', %s, %s, %s, %s,
             **%s)"""

    def check_df(name, ref_energy, ref_loss, ref_energy_lfe, ref_loss_lfe,
                 **kwargs_override):
        """Compute an RPA EELS spectrum and compare its peaks to references.

        ``kwargs_override`` replaces entries of the default
        ``DielectricFunction`` keyword arguments.  Column 1 of the dumped
        spectrum is compared against ``ref_energy``/``ref_loss``.  Column 2
        is compared against ``ref_energy_lfe``/``ref_loss_lfe``; it is
        presumably the spectrum including local-field effects.  The
        generated reference line is appended to ``scriptlines``.
        """
        kwargs = dict(calc=calc, frequencies=w.copy(), eta=0.5, ecut=30,
                      txt='df.%s.txt' % name)
        kwargs.update(kwargs_override)
        df = DielectricFunction(**kwargs)
        fname = 'dfdump.%s.dat' % name
        df.get_eels_spectrum('RPA', q_c=q, filename=fname)
        # Synchronize before every rank reads the dump file back.
        world.barrier()
        d = np.loadtxt(fname, delimiter=',')

        loss = d[:, 1]
        loss_lfe = d[:, 2]
        energies = d[:, 0]

        energy, peakloss = getpeak(energies, loss)
        energy_lfe, peakloss_lfe = getpeak(energies, loss_lfe)

        check(name, energy, peakloss, ref_energy, ref_loss)
        check('%s-lfe' % name, energy_lfe, peakloss_lfe, ref_energy_lfe,
              ref_loss_lfe)

        line = template % (name, energy, peakloss, energy_lfe, peakloss_lfe,
                           repr(kwargs_override))
        scriptlines.append(line)

    # When new reference values are wanted, the check_df() calls below can
    # be regenerated from the lines printed by the 'scriptlines' loop.
    #
    # TODO: the Coulomb-truncation (vcut) choices are not implemented yet.
    # The '3d' case apparently used to pass
    # **{'rpad': array([1, 1, 1]), 'vcut': '3D'} as overrides.  The disabled
    # 2D case, with its previous reference values, was:
    #     check_df('2d', 20.10, 2.449662058530, 26.80, 1.538080502420,
    #              **{'rpad': array([1, 1, 1]), 'vcut': '2D'})
    check_df('3d', 20.20, 2.505295593820, 26.90, 1.748517033160)

    pp()
    pp('Insert lines into script to set new reference values:')
    for line in scriptlines:
        pp(line)
    pp()

    for err in energy_errs:
        # The reference energies are multiples of dw.  Assuming the dumped
        # energies reproduce the grid w, an error below dw / 4 means the
        # peak lies on exactly the same grid point as the reference.
        assert err < dw / 4.0, err
    for err in loss_errs:
        assert err < 1e-6, err
128