1# coding: utf-8
"""Work subclasses related to DFPT (density-functional perturbation theory)."""
3
4from .works import Work, MergeDdb
5
6
class ElasticWork(Work, MergeDdb):
    """
    This Work computes the elastic constants and (optionally) the piezoelectric tensor.
    It consists of response-function calculations for:

        * rigid-atom elastic tensor
        * rigid-atom piezoelectric tensor
        * interatomic force constants at gamma
        * Born effective charges

    The structure is assumed to be already relaxed.

    The Work built by `from_scf_input` contains:

        - 1 GS task (SCF, or NSCF when a density dependency is given)
        - DDK tasks (only if `with_piezo` or `with_dde`)
        - DDE tasks for the electric-field perturbation (only if `with_dde`)
        - Phonon tasks (only if `with_relaxed_ion`)
        - Elastic (strain) tasks

    When `with_piezo` is True, the phonon and elastic tasks read the DDK files produced
    at the beginning. The DDE tasks always read the DDK files.
    """

    @classmethod
    def from_scf_input(cls, scf_input, with_relaxed_ion=True, with_piezo=False, with_dde=False,
                       tolerances=None, den_deps=None, manager=None):
        """
        Build the Work from an input for a ground-state SCF calculation.

        All the tasks (GS + DFPT) are packed in a single Work.

        Args:
            scf_input: Input for the ground-state calculation. It is also used to generate
                the DDK, DDE and strain/phonon inputs.
            with_relaxed_ion: If True, also generate phonon perturbations together with the strain ones.
            with_piezo: If True, compute the DDKs and make the phonon and strain tasks depend on them.
            with_dde: Compute electric field perturbations. The DDKs are computed as well.
            tolerances: Dict of tolerances. Recognized keys: "nscf", "ddk", "dde", "strain".
                tolerances["nscf"] is a dict whose optional "tolwfr" entry is used for the NSCF run
                (default 1.0e-20). The other entries are passed as `tolerance` to the input generators.
            den_deps: If not None, dependencies for an NSCF task (iscf=-2) that replaces the SCF task.
            manager: Manager passed to the Work and to the input generators.

        Returns:
            New instance of the class.
        """
        tolerances = {} if tolerances is None else tolerances
        new = cls(manager=manager)

        # Register the task producing the WFK file: SCF by default, NSCF if den_deps is given.
        if den_deps is None:
            wfk_task = new.register_scf_task(scf_input)
        else:
            # A "nscf" entry without "tolwfr" (or set to None) falls back to the default
            # instead of raising KeyError.
            nscf_tolerance = tolerances.get("nscf") or {}
            tolwfr = nscf_tolerance.get("tolwfr", 1.0e-20)
            nscf_input = scf_input.new_with_vars(iscf=-2, tolwfr=tolwfr)
            wfk_task = new.register_nscf_task(nscf_input, deps=den_deps)

        # Dependencies on the DDK tasks. Empty unless with_piezo or with_dde.
        ddk_deps = {}
        if with_piezo or with_dde:
            # DDK wavefunctions are needed for the piezoelectric tensor and the Born effective charges.
            ddk_multi = scf_input.make_ddk_inputs(tolerance=tolerances.get("ddk"), manager=manager)
            for inp in ddk_multi:
                ddk_task = new.register_ddk_task(inp, deps={wfk_task: "WFK"})
                ddk_deps[ddk_task] = "DDK"

        if with_dde:
            # Electric field perturbation: reads the WFK and all the DDK files.
            dde_multi = scf_input.make_dde_inputs(tolerance=tolerances.get("dde"), use_symmetries=True,
                                                  manager=manager)
            dde_deps = {wfk_task: "WFK"}
            dde_deps.update(ddk_deps)
            for inp in dde_multi:
                new.register_dde_task(inp, deps=dde_deps)

        # Build input files for strain and (optionally) phonons.
        # Materialize once because the inputs are iterated twice below.
        strain_inputs = list(scf_input.make_strain_perts_inputs(
            tolerance=tolerances.get("strain"), manager=manager,
            phonon_pert=with_relaxed_ion, kptopt=2))

        if with_relaxed_ion:
            # Phonon perturbations (read DDK if piezo).
            ph_deps = {wfk_task: "WFK"}
            if with_piezo:
                ph_deps.update(ddk_deps)
            for inp in strain_inputs:
                if inp.get("rfphon", 0) == 1:
                    new.register_phonon_task(inp, deps=ph_deps)

        # Finally compute the strain perturbations (read DDK if piezo).
        elast_deps = {wfk_task: "WFK"}
        if with_piezo:
            elast_deps.update(ddk_deps)
        for inp in strain_inputs:
            if inp.get("rfstrs", 0) != 0:
                new.register_elastic_task(inp, deps=elast_deps)

        return new

    def on_all_ok(self):
        """
        This method is called when all the tasks of the Work reach S_OK.
        It runs `mrgddb` sequentially on the local machine to produce
        the final DDB file in the outdir of the `Work`.

        Returns:
            `self.Results` with returncode 0 and message "DDB merge done".
        """
        # Merge the DDB files. Source DDBs are kept and non-DFPT tasks are not excluded.
        self.merge_ddb_files(delete_source_ddbs=False, only_dfpt_tasks=False)
        return self.Results(node=self, returncode=0, message="DDB merge done")
113
114
class NscfDdksWork(Work):
    """
    Work that starts from a DEN file and computes the KS energies with a non self-consistent
    task on a dense k-mesh including empty states.

    The NSCF task is followed by DDK tasks run with nstep = 1. The first-order change of the
    wavefunctions is not converged, but only the DDK matrix elements are needed.
    Mainly used to prepare optic calculations or other post-processing steps requiring the DDKs.

    Attributes set by `from_scf_task`:
        task_with_ks_energies: the NSCF task producing the KS energies.
        ddk_tasks: list with the DDK tasks, in registration order.
    """

    @classmethod
    def from_scf_task(cls, scf_task, ddk_ngkpt, ddk_shiftk, ddk_nband, manager=None):
        """
        Build a NscfDdksWork from a ground-state task.

        Args:
            scf_task: GS task. Must produce the DEN file required for the NSCF run.
            ddk_ngkpt: k-mesh used for the NSCF run and the non self-consistent DDK tasks.
            ddk_shiftk: k-mesh shifts.
            ddk_nband: Number of bands (occupied + empty) used in the NSCF task and the DDK tasks.
            manager: TaskManager instance. Use default if None.

        Returns:
            NscfDdksWork instance.
        """
        work = cls(manager=manager)

        # NSCF run with ddk_nband states on the IBZ (kptopt=1), reading the DEN of scf_task.
        base_input = scf_task.input.deepcopy()
        base_input.set_vars(nband=ddk_nband, prtwf=1)
        base_input.set_kmesh(ddk_ngkpt, ddk_shiftk, kptopt=1)
        ks_task = work.register_nscf_task(base_input, deps={scf_task: "DEN"})

        # The KS energies used by optic are produced by this task.
        work.task_with_ks_energies = ks_task

        # One-shot DDK inputs with kptopt=2 (time-reversal only, as required by DDK;
        # time-reversal symmetry can be exploited in optic).
        one_shot_inputs = base_input.make_ddk_inputs(kptopt=2)
        work.ddk_tasks = []
        for one_shot in one_shot_inputs:
            # FIXME: prtwf should be 0, but then the DDK.nc output would need to be replaced.
            one_shot.set_vars(nstep=1, nline=0, prtwf=1)
            # FIXME: register_task is used instead of register_ddk_task because of
            # a conflict between DDK.nc and DDK.
            registered = work.register_task(one_shot, deps={ks_task: "WFK"})
            work.ddk_tasks.append(registered)

        return work
166