1# -*- coding: utf-8 -*-
2#
3# Copyright (c) 2017, the cclib development team
4#
5# This file is part of cclib (http://cclib.github.io) and is distributed under
6# the terms of the BSD 3-Clause License.
7
8"""A writer for wfx format files."""
9
10import os.path
11import numpy
12
13from cclib.io import filewriter
14from cclib.parser import utils
15
16
# Number of Cartesian components in a shell of each type:
# 3 for a p shell, 6 for a d shell, and so on.
ORBITAL_COUNT = {'S': 1, 'P': 3, 'D': 6, 'F': 10, 'G': 15, 'H': 21}

# Shell types in order of increasing angular momentum.
ORBITAL_NAMES = 'SPDFGH'

# Primitive type number of the first component of each shell type.
# s starts at 1, p at 2 and d at 5: every shell type starts right after the
# last component of the previous one.
ORBITAL_INDICES = {'S': 1}
for lower, upper in zip(ORBITAL_NAMES, ORBITAL_NAMES[1:]):
    ORBITAL_INDICES[upper] = ORBITAL_INDICES[lower] + ORBITAL_COUNT[lower]

PI_CUBE_INV = (2.0 / numpy.pi) ** 3

# Float formatting template.
WFX_FIELD_FMT = '%22.11E'

# Precomputed values of l+m+n for the primitive types s (1) to g (35),
# used in MO normalization.
_L = {
    prim_type: ang_mom
    for ang_mom, (first, last) in enumerate([
        (1, 1),     # s
        (2, 4),     # p
        (5, 10),    # d
        (11, 20),   # f
        (21, 35),   # g
    ])
    for prim_type in range(first, last + 1)
}

# Precomputed values of ((2l-1)!! * (2m-1)!! * (2n-1)!!)**2 for the primitive
# types s (1) to g (35).  Note that the stored value is already squared,
# e.g. 9 = 3**2 for DXX and 11025 = 105**2 for GXXXX.
_M = {
    prim_type: squared_product
    for squared_product, first, last in (
        (1, 1, 4),
        (9, 5, 7),
        (1, 8, 10),
        (225, 11, 13),
        (9, 14, 19),
        (1, 20, 20),
        (11025, 21, 23),
        (225, 24, 29),
        (81, 30, 32),
        (9, 33, 35),
    )
    for prim_type in range(first, last + 1)
}
59
60
61def _section(section_name, section_data):
62    """Add opening/closing section_name tags to data."""
63    opening_tag = ['<' + section_name + '>']
64    closing_tag = ['</' + section_name + '>']
65
66    section = None
67    if isinstance(section_data, list):
68        section = opening_tag + section_data + closing_tag
69    elif isinstance(section_data, str):
70        section = opening_tag + (' ' + section_data).split('\n') + closing_tag
71    elif isinstance(section_data, int) or isinstance(section_data, float):
72        section = opening_tag + [' ' + str(section_data)] + closing_tag
73    return section
74
75
def _list_format(data, per_line, style=WFX_FIELD_FMT):
    """Format a sequence as lines holding per_line items each.

    data     -- sequence of values to format.
    per_line -- number of values on every full line.
    style    -- %-format template applied to each single value.

    Returns a list of strings; the last one holds the remaining values
    when len(data) is not a multiple of per_line.
    """
    remainder = len(data) % per_line
    full_length = len(data) - remainder
    row_template = style * per_line

    rows = []
    for start in range(0, full_length, per_line):
        rows.append(row_template % tuple(data[start:start + per_line]))

    if remainder:
        # Shorter template for the incomplete last line.
        rows.append((style * remainder) % tuple(data[-remainder:]))
    return rows
88
89
class WFXWriter(filewriter.Writer):
    """A writer for wfx files."""

    required_attrs = ('natom', 'atomcoords', 'atomnos', 'gbasis', 'charge',
                      'homos', 'mult', 'mocoeffs')

    def _title(self):
        """Section: Title.

        Return the job file name without directory and extension followed
        by a cclib signature, or only the signature when no job file name
        is set.
        """
        title = "Written by cclib."
        if self.jobfilename is None:
            return title
        basename = os.path.basename(os.path.splitext(self.jobfilename)[0])
        return basename + '. ' + title

    def _keywords(self):
        """Section: Keywords.
        Return one of GTO, GIAO, CSGT keyword."""
        # Currently only GTO is supported.
        return 'GTO'

    def _no_of_nuclei(self):
        """Section: Number of Nuclei."""
        return self.ccdata.natom

    def _no_of_prims(self):
        """Section: Number of Primitives.

        Every shell in gbasis contributes (number of Cartesian components
        of its type) * (number of primitives in its contraction).
        """
        return sum(ORBITAL_COUNT[prims[0]] * len(prims[1])
                   for atom in self.ccdata.gbasis
                   for prims in atom)

    def _no_of_mos(self):
        """Section: Number of Occupied Molecular Orbitals.

        homos holds zero-based indices, hence the + 1.
        """
        return int(max(self.ccdata.homos)) + 1

    def _no_of_perturbations(self):
        """Section: Number of Perturbation.

        This is usually zero.  For GIAO it should be 3
        (corresponding to Lx, Ly and Lz), and
        for CSGT it should be 6
        (corresponding to Lx, Ly, Lz, Px, Py and Pz).
        """
        keywords = self._keywords()
        if 'GIAO' in keywords:
            return 3
        if 'CSGT' in keywords:
            return 6
        return 0

    def _nuclear_names(self):
        """Section: Nuclear Names.
        Names of nuclei present in the molecule, i.e. the element symbol
        followed by the one-based position of the atom:

        O1
        H2
        H3
        """
        return [self.pt.element[Z] + str(i) for i, Z in
                enumerate(self.ccdata.atomnos, start=1)]

    def _atomic_nos(self):
        """Section: Atomic Numbers."""
        return [str(Z) for Z in self.ccdata.atomnos]

    def _nuclear_charges(self):
        """Section: Nuclear Charges.

        The atomic numbers, reduced by the core electrons when the parsed
        data has a coreelectrons attribute.
        """
        charges = self.ccdata.atomnos
        if hasattr(self.ccdata, 'coreelectrons'):
            charges = charges - self.ccdata.coreelectrons
        return [WFX_FIELD_FMT % Z for Z in charges]

    def _nuclear_coords(self):
        """Section: Nuclear Cartesian Coordinates.
        Nuclear coordinates of the last geometry in atomcoords, converted
        from Angstrom to Bohr."""
        coord_template = WFX_FIELD_FMT * 3
        nuc_coords = []
        for coord in self.ccdata.atomcoords[-1]:
            coord_bohr = utils.convertor(coord, 'Angstrom', 'bohr')
            nuc_coords.append(coord_template % tuple(coord_bohr))
        return nuc_coords

    def _net_charge(self):
        """Section: Net Charge.
        Net charge on molecule."""
        return WFX_FIELD_FMT % self.ccdata.charge

    def _no_electrons(self):
        """Section: Number of Electrons."""
        return int(self.ccdata.nelectrons)

    def _no_alpha_electrons(self):
        """Section: Number of Alpha Electrons.

        Computed as (N + mult - 1) // 2, where N is the sum of the nuclear
        charges minus the net charge.  Core electrons are subtracted from
        the atomic numbers only when the parsed data provides them, the same
        way _nuclear_charges treats them.
        """
        nuclear_charges = numpy.asarray(self.ccdata.atomnos)
        if hasattr(self.ccdata, 'coreelectrons'):
            nuclear_charges = nuclear_charges - self.ccdata.coreelectrons
        no_electrons = numpy.sum(nuclear_charges) - self.ccdata.charge
        no_alpha = (no_electrons + (self.ccdata.mult - 1)) // 2
        return int(no_alpha)

    def _no_beta_electrons(self):
        """Section: Number of Beta Electrons.

        Whatever is left of nelectrons after the alpha electrons.
        """
        return int(self.ccdata.nelectrons - self._no_alpha_electrons())

    def _spin_mult(self):
        """Section: Electronic Spin Multiplicity"""
        return self.ccdata.mult

    def _prim_centers(self):
        """Section: Primitive Centers.
        List of nuclear numbers (one-based) upon which the primitive basis
        functions are centered, formatted ten per line."""
        prim_centers = []
        for nuc_num, atom in enumerate(self.ccdata.gbasis, start=1):
            for prims in atom:
                n_prims = ORBITAL_COUNT[prims[0]] * len(prims[1])
                prim_centers.extend([nuc_num] * n_prims)

        return _list_format(prim_centers, 10, '%d ')

    def _rearrange_modata(self, data):
        """Rearrange MO related data according to the expected order of
        Cartesian gaussian primitive types in wfx format.
        cclib parses mocoeffs in the order they occur in output files.

        The entries to move are located through the list returned by
        _get_prim_types(): the positions holding primitive type 17 (FXYY)
        and primitive type 16 (FYYZ).

        data -- list or numpy array.  An array is converted with tolist()
                first; a list is rearranged in place.

        Returns the rearranged list.
        """
        prim_types = self._get_prim_types()
        if isinstance(data, numpy.ndarray):
            data = data.tolist()

        pos_yyx = [key for key, val in enumerate(prim_types)
                   if val == 17]
        pos_yyz = [key for key, val in enumerate(prim_types)
                   if val == 16]

        # Move the entry at each type 17 position three places towards the
        # start of the list.
        for pos in pos_yyx:
            data.insert(pos-3, data.pop(pos))
        # Move the entry right after each type 16 position two places
        # towards the end of the list.
        # NOTE(review): pos + 1 presumably accounts for the entry that the
        # loop above moved in front of it -- confirm for shells with more
        # than one primitive.
        for pos in pos_yyz:
            data.insert(pos+3, data.pop(pos + 1))

        return data

    def _get_prim_types(self):
        """List of primitive types.
        Definition of the Cartesian Gaussian primitive types is as follows:
        1 S, 2 PX, 3 PY, 4 PZ, 5 DXX, 6 DYY, 7 DZZ, 8 DXY, 9 DXZ, 10 DYZ,
        11 FXXX, 12 FYYY, 13 FZZZ, 14 FXXY, 15 FXXZ, 16 FYYZ, 17 FXYY,
        18 FXZZ, 19 FYZZ, 20 FXYZ,
        21 GXXXX, 22 GYYYY, 23 GZZZZ, 24 GXXXY, 25 GXXXZ, 26 GXYYY,
        27 GYYYZ, 28 GXZZZ,
        29 GYZZZ, 30 GXXYY, 31 GXXZZ, 32 GYYZZ, 33 GXXYZ, 34 GXYYZ,
        35 GXYZZ,
        36 HZZZZZ, 37 HYZZZZ, 38 HYYZZZ, 39 HYYYZZ,
        40 HYYYYZ, 41 HYYYYY, 42 HXZZZZ, 43 HXYZZZ, 44 HXYYZZ,
        45 HXYYYZ, 46 HXYYYY, 47 HXXZZZ, 48 HXXYZZ, 49 HXXYYZ,
        50 HXXYYY, 51 HXXXZZ, 52 HXXXYZ, 53 HXXXYY, 54 HXXXXZ, 55 HXXXXY,
        56 HXXXXX
        Spherical basis are not currently supported by the writer.

        For every shell the type numbers are listed component by component,
        each one repeated once per primitive of the contraction.
        """
        prim_types = []
        for atom in self.ccdata.gbasis:
            for prims in atom:
                first_type = ORBITAL_INDICES[prims[0]]
                n_prims = len(prims[1])
                for i in range(ORBITAL_COUNT[prims[0]]):
                    prim_types.extend([first_type + i] * n_prims)
        return prim_types

    def _prim_types(self):
        """Section: Primitive Types.
        Primitive type numbers formatted ten per line."""
        prim_types = self._get_prim_types()
        # GAMESS specific reordering.
        if self.ccdata.metadata['package'] == 'GAMESS':
            prim_types = self._rearrange_modata(prim_types)
        return _list_format(prim_types, 10, '%d ')

    def _prim_exps(self):
        """Section: Primitive Exponents.
        Primitive exponents formatted five per line.  The exponents of a
        shell are repeated once for each of its Cartesian components."""
        prim_exps = []
        for atom in self.ccdata.gbasis:
            for prims in atom:
                shell_exps = [prim[0] for prim in prims[1]]
                prim_exps.extend(shell_exps * ORBITAL_COUNT[prims[0]])
        return _list_format(prim_exps, 5)

    def _mo_occup_nos(self):
        """Section: Molecular Orbital Occupation Numbers.

        With a single entry in homos: 2 for every electron pair plus 1 for
        an odd electron.  Otherwise: 1 for every alpha electron followed by
        1 for every beta electron.
        """
        singly = WFX_FIELD_FMT % (1)
        if len(self.ccdata.homos) == 1:
            electrons = self._no_electrons()
            doubly = WFX_FIELD_FMT % (2)
            return [doubly] * (electrons // 2) + [singly] * (electrons % 2)
        alpha = self._no_alpha_electrons()
        beta = self._no_beta_electrons()
        return [singly] * alpha + [singly] * beta

    def _mo_energies(self):
        """Section: Molecular Orbital Energies.

        Energies of the occupied alpha orbitals and, for mult > 1, of the
        occupied beta orbitals, converted from eV to hartree.

        NOTE(review): beta energies are selected by mult > 1 whereas
        _mo_occup_nos and _mo_spin_types branch on len(homos) -- confirm
        that these agree for restricted open-shell and unrestricted singlet
        data.
        """
        def to_field(mo_energy):
            return WFX_FIELD_FMT % (
                utils.convertor(mo_energy, 'eV', 'hartree'))

        alpha_electrons = self._no_alpha_electrons()
        beta_electrons = self._no_beta_electrons()
        mo_energies = [to_field(mo_energy) for mo_energy in
                       self.ccdata.moenergies[0][:alpha_electrons]]
        if self.ccdata.mult > 1:
            mo_energies.extend(
                to_field(mo_energy) for mo_energy in
                self.ccdata.moenergies[1][:beta_electrons])
        return mo_energies

    def _mo_spin_types(self):
        """Section: Molecular Orbital Spin Types.

        With a single entry in homos: 'Alpha and Beta' for every electron
        pair plus 'Alpha' for an odd electron.  Otherwise: 'Alpha' for every
        alpha electron followed by 'Beta' for every beta electron.
        """
        if len(self.ccdata.homos) == 1:
            electrons = self._no_electrons()
            return (['Alpha and Beta'] * (electrons // 2) +
                    ['Alpha'] * (electrons % 2))
        alpha = self._no_alpha_electrons()
        beta = self._no_beta_electrons()
        return ['Alpha'] * alpha + ['Beta'] * beta

    def _normalize(self, prim_type, alpha=1.0):
        """Normalization factor for Cartesian Gaussian Functions.

        N**4 = (2/pi)**3 * 2**(4L) * alpha**(3 + 2L) / M**2,
        L = l+m+n,
        M = ((2l-1)!! * (2m-1)!! * (2n-1)!!)

        L is looked up in _L and the already squared value M**2 in _M,
        which only cover primitive types 1 (s) to 35 (g).

        prim_type -- primitive type number, see _get_prim_types.
        alpha     -- exponent of the primitive.
        """
        L = _L[prim_type]
        M_squared = _M[prim_type]
        norm_four = PI_CUBE_INV * 2**(4*L) * alpha**(3+2*L) / M_squared
        return numpy.power(norm_four, 1/4.0)

    def _rearrange_mocoeffs(self, mocoeffs):
        """Rearrange cartesian F functions in mocoeffs.
        Expected order:
        xxx, yyy, zzz, xyy, xxy, xxz, xzz, yzz, yyz, xyz
        cclib's order for GAMESS:
        XXX, YYY, ZZZ, XXY, XXZ, YYX, YYZ, ZZX, ZZY, XYZ

        mocoeffs -- numpy array with the coefficients of one MO, in the
                    order of ccdata.aonames.

        Returns the rearranged coefficients as a list.
        """
        aonames = self.ccdata.aonames
        mocoeffs = mocoeffs.tolist()

        pos_yyx = [key for key, val in enumerate(aonames)
                   if '_YYX' in val]
        pos_yyz = [key for key, val in enumerate(aonames)
                   if '_YYZ' in val]

        # YYX moves two places towards the start, in front of XXY.
        for pos in pos_yyx:
            mocoeffs.insert(pos-2, mocoeffs.pop(pos))
        # YYZ moves two places towards the end, behind ZZY.
        for pos in pos_yyz:
            mocoeffs.insert(pos+2, mocoeffs.pop(pos))

        return mocoeffs

    def _norm_mat(self):
        """Calculate normalization matrix for normalizing MOcoeffs.

        Returns a tuple (norm_mat, mo_count, prim_coeff):
        norm_mat   -- for every primitive, its normalization factor
                      multiplied by its contraction coefficient,
        mo_count   -- for every Cartesian component of every shell, the
                      number of primitives in the contraction,
        prim_coeff -- contraction coefficient of every primitive.
        """
        alpha = []
        prim_coeff = []
        mo_count = []

        for atom in self.ccdata.gbasis:
            for prims in atom:
                # Exponents and coefficients of the contraction are repeated
                # once for every Cartesian component of the shell.
                n_components = ORBITAL_COUNT[prims[0]]
                mo_count += [len(prims[1])] * n_components
                alpha += [prim[0] for prim in prims[1]] * n_components
                prim_coeff += [prim[1] for prim in prims[1]] * n_components

        prim_type = self._get_prim_types()

        # GAMESS specific reordering.
        if self.ccdata.metadata['package'] == 'GAMESS':
            prim_type = self._rearrange_modata(prim_type)
            alpha = self._rearrange_modata(alpha)
            prim_coeff = self._rearrange_modata(prim_coeff)

        norm_mat = [self._normalize(prim_type[i], alpha[i]) * prim_coeff[i]
                    for i in range(len(prim_coeff))]

        return (norm_mat, mo_count, prim_coeff)

    def _nmos(self):
        """Return number of molecular orbitals to be printed.

        This is nelectrons for mult > 1 and the number of occupied MOs
        otherwise.

        NOTE(review): the branch is on mult, not on len(homos) as in
        _mo_occup_nos -- confirm both agree for all supported calculations.
        """
        if self.ccdata.mult > 1:
            return self.ccdata.nelectrons
        return self._no_of_mos()

    def _prim_mocoeff(self, mo_count):
        """Return primitive mocoeffs list.

        For every set in ccdata.mocoeffs and each of its first _nmos()
        orbitals, coefficient number k is repeated mo_count[k] times.

        mo_count -- number of primitives per basis function, as returned
                    by _norm_mat.
        """
        # Neither value changes while looping, so look them up only once.
        nmos = self._nmos()
        is_gamess = self.ccdata.metadata['package'] == 'GAMESS'

        prim_mocoeff = []
        for spin_mocoeffs in self.ccdata.mocoeffs:
            for j in range(nmos):
                mocoeffs = spin_mocoeffs[j]
                if is_gamess:
                    mocoeffs = self._rearrange_mocoeffs(mocoeffs)
                for k, mocoeff in enumerate(mocoeffs):
                    prim_mocoeff.extend([mocoeff] * mo_count[k])

        return prim_mocoeff

    def _normalized_mocoeffs(self):
        """Raw-Primitive Expansion coefficients for each normalized MO.

        Returns one list per MO, holding for every primitive the product of
        its entry in the normalization matrix and the MO coefficient of its
        basis function.
        """
        # Normalization Matrix.
        norm_mat, mo_count, prim_coeff = self._norm_mat()

        prim_mocoeff = self._prim_mocoeff(mo_count)
        nprims = len(prim_coeff)

        norm_mocoeffs = []
        for mo_num in range(self._nmos()):
            # prim_mocoeff is flat, each MO takes up nprims entries.
            offset = mo_num * nprims
            norm_mocoeffs.append([norm_mat[i] * prim_mocoeff[offset + i]
                                  for i in range(nprims)])

        return norm_mocoeffs

    def _mo_prim_coeffs(self):
        """Section: Molecular Orbital Primitive Coefficients.

        For every MO an 'MO Number' subsection with its one-based number,
        followed by its coefficients formatted five per line.
        """
        mocoeffs_section = []
        for mo_num, mocoeffs in enumerate(self._normalized_mocoeffs(),
                                          start=1):
            mocoeffs_section.extend(_section('MO Number', mo_num))
            mocoeffs_section.extend(_list_format(mocoeffs, 5))
        return mocoeffs_section

    def _energy(self):
        """Section: Energy = T + Vne + Vee + Vnn.
        The total energy of the molecule.
        HF and KSDFT: SCF energy        (scfenergies),
        MP2         : MP2 total energy  (mpenergies),
        CCSD        : CCSD total energy (ccenergies).

        The first available of ccenergies, mpenergies and scfenergies is
        used, and its last value is converted from eV to hartree.
        Raises filewriter.MissingAttributeError if none is available.
        """
        if hasattr(self.ccdata, 'ccenergies'):
            energy = self.ccdata.ccenergies[-1]
        elif hasattr(self.ccdata, 'mpenergies'):
            energy = self.ccdata.mpenergies[-1][-1]
        elif hasattr(self.ccdata, 'scfenergies'):
            energy = self.ccdata.scfenergies[-1]
        else:
            raise filewriter.MissingAttributeError(
                'scfenergies/mpenergies/ccenergies')
        return WFX_FIELD_FMT % (utils.convertor(energy, 'eV', 'hartree'))

    def _virial_ratio(self):
        """Section: Virial Ratio (-V/T).

        The value is not computed from the parsed data: the field is
        required, so the expected value 2.0 is written.
        """
        # Hardcoding expected value for Required Field.
        return WFX_FIELD_FMT % (2.0)

    def generate_repr(self):
        """Generate the wfx representation of the logfile data.

        Returns the sections joined by newlines, ending with a newline.
        Raises filewriter.MissingAttributeError when a required section
        cannot be generated; optional sections that fail are left out.
        """

        # sections:(Function returning data for section,
        #           Section heading,
        #           Required)

        sections = [
            (self._title, "Title", True),
            (self._keywords, "Keywords", True),
            (self._no_of_nuclei, "Number of Nuclei", True),
            (self._no_of_prims, "Number of Primitives", True),
            (self._no_of_mos, "Number of Occupied Molecular Orbitals", True),
            (self._no_of_perturbations, "Number of Perturbations", True),
            (self._nuclear_names, "Nuclear Names", True),
            (self._atomic_nos, "Atomic Numbers", True),
            (self._nuclear_charges, "Nuclear Charges", True),
            (self._nuclear_coords, "Nuclear Cartesian Coordinates", True),
            (self._net_charge, "Net Charge", True),
            (self._no_electrons, "Number of Electrons", True),
            (self._no_alpha_electrons, "Number of Alpha Electrons", True),
            (self._no_beta_electrons, "Number of Beta Electrons", True),
            (self._spin_mult, "Electronic Spin Multiplicity", False),
            # (self._model, "Model", False),
            (self._prim_centers, "Primitive Centers", True),
            (self._prim_types, "Primitive Types", True),
            (self._prim_exps, "Primitive Exponents", True),
            (self._mo_occup_nos,
             "Molecular Orbital Occupation Numbers", True),
            (self._mo_energies, "Molecular Orbital Energies", True),
            (self._mo_spin_types, "Molecular Orbital Spin Types", True),
            (self._mo_prim_coeffs,
             "Molecular Orbital Primitive Coefficients", True),
            (self._energy, "Energy = T + Vne + Vee + Vnn", True),
            # (self._nuc_energy_gradients,
            #  "Nuclear Cartesian Energy Gradients", False),
            # (self._nuc_virial,
            #  "Nuclear Virial of Energy-Gradient-Based Forces on Nuclei, W",
            #  False),
            (self._virial_ratio, "Virial Ratio (-V/T)", True),
        ]

        wfx_lines = []

        for section_module, section_name, section_required in sections:
            try:
                section_data = section_module()
                wfx_lines.extend(_section(section_name, section_data))
            except Exception:
                # Only ordinary errors are handled here, so that
                # KeyboardInterrupt and SystemExit still propagate.
                if section_required:
                    raise filewriter.MissingAttributeError(
                        'Unable to write required wfx section: '
                        + section_name)

        wfx_lines.append('')
        return '\n'.join(wfx_lines)
516