1import argparse
2from typing import List, Union
3
4from ase import Atoms
5import numpy as np
6
7from gpaw import GPAW
8from gpaw.point_groups import SymmetryChecker, point_group_names
9from gpaw.typing import Array1D, Array3D
10
11
class CubeCalc:
    """Present the contents of a cube file as a minimal calculator.

    Only a small calculator-like surface is provided: an ``atoms``
    attribute, one spin channel, one eigenvalue (zero) and a pseudo wave
    function that is simply the stored cube data.  This lets cube data be
    passed where the point-group checker expects a calculator.
    """

    def __init__(self, function: Array3D, atoms: Atoms):
        # Keep both the grid data and the atoms it was read together with.
        self.atoms = atoms
        self.function = function

    def get_pseudo_wave_function(self,
                                 band: int,
                                 spin: int,
                                 pad: bool) -> Array3D:
        """Return the stored cube data.

        *band*, *spin* and *pad* are accepted for interface compatibility
        only; there is exactly one function, so they are ignored.
        """
        del band, spin, pad  # unused by design
        return self.function

    def get_eigenvalues(self, spin: int) -> Array1D:
        """Return a single placeholder eigenvalue of 0.0 (spin ignored)."""
        return np.array([0.0])

    def get_number_of_spins(self) -> int:
        """A cube file holds a single, spin-less function."""
        return 1
29
30
def main(argv: List[str] = None) -> None:
    """Analyse the point-group symmetry of atoms and wave functions.

    Command-line entry point (``python3 -m gpaw.point_groups``).  The atoms
    are read from the input file and checked against the requested point
    group.  For ``.gpw`` files the wave functions in the band range given
    by ``--bands`` are analysed as well; for ``.cube`` files the cube data
    is analysed as a single wave function (band 0).  Any other file is
    only used for its atoms.

    Invalid command-line input (unknown point group, non-numeric radius,
    malformed band range or axis specification, a ``--center`` that
    matches no atom) is reported through ``parser.error``, i.e. a usage
    message and ``SystemExit``.

    Parameters
    ----------
    argv:
        Command-line arguments without the program name.  ``None`` lets
        argparse fall back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        prog='python3 -m gpaw.point_groups',
        description='Analyse point-group of atoms and wave-functions.')
    add = parser.add_argument
    add('pg', metavar='point-group', choices=point_group_names,
        help='Name of point-group: C2, C2v, C3v, D2d, D3h, D5, D5h, '
        'Ico, Ih, Oh, Td or Th.')
    add('file', metavar='input-file',
        help='Cube-file, gpw-file or something else with atoms in it.')
    add('-c', '--center', help='Center specified as one or more atoms.  '
        'Use chemical symbols or sequence numbers.')
    # type=float: a non-numeric radius becomes a usage error, not a
    # traceback later on.
    add('-r', '--radius', default=2.5, type=float,
        help='Cutoff radius (in Å) used for wave function overlaps.')
    add('-b', '--bands', default=':', metavar='N1:N2',
        help='Band range.')
    add('-a', '--axes', default='',
        help='Example: "-a z=x,x=-y".')
    if hasattr(parser, 'parse_intermixed_args'):
        args = parser.parse_intermixed_args(argv)  # needs Python 3.7
    else:
        args = parser.parse_args(argv)

    calc: Union[None, GPAW, CubeCalc]

    if args.file.endswith('.gpw'):
        # Parse the band range before the (possibly slow) file read so a
        # typo is reported immediately.  An empty field counts as 0; the
        # pair is passed unchanged to SymmetryChecker.check_calculation.
        try:
            n1, n2 = (int(x) if x else 0 for x in args.bands.split(':'))
        except ValueError:
            parser.error(f'invalid band range {args.bands!r} '
                         '(expected N1:N2)')
        calc = GPAW(args.file)
        atoms = calc.atoms
    elif args.file.endswith('.cube'):
        from ase.io.cube import read_cube
        with open(args.file) as fd:
            dct = read_cube(fd)
        # The cube data is treated as one wave function, i.e. band 0 only.
        calc = CubeCalc(dct['data'], dct['atoms'])
        atoms = dct['atoms']
        n1 = 0
        n2 = 1
    else:
        from ase.io import read
        atoms = read(args.file)
        calc = None  # no wave functions available: only atoms are checked

    if args.center:
        # Center = mean position of the atoms selected by chemical symbol
        # or by 0-based index.
        selection = set(args.center.split(','))
        center = np.zeros(3)
        n = 0
        for a, (symbol, position) in enumerate(zip(atoms.symbols,
                                                   atoms.positions)):
            if symbol in selection or str(a) in selection:
                center += position
                n += 1
        if n == 0:
            # Dividing by zero would silently produce a NaN center.
            parser.error(f'no atoms match --center {args.center!r}')
        center /= n
        print('Center:', center, f'(atoms: {n})')
    else:
        # Default: middle of the unit cell (half the sum of the cell
        # vectors).  No atom count exists here, so none is printed.
        center = atoms.cell.sum(0) / 2
        print('Center:', center, '(middle of unit cell)')

    # "-a z=x,x=-y" -> {'z': 'x', 'x': '-y'}, forwarded as keyword
    # arguments to SymmetryChecker.
    axes = {}
    for item in args.axes.split(',') if args.axes else []:
        try:
            axis1, axis2 = item.split('=')
        except ValueError:
            parser.error(f'invalid axis specification {item!r} '
                         '(expected e.g. z=x)')
        axes[axis1.strip()] = axis2.strip()

    checker = SymmetryChecker(args.pg, center, args.radius, **axes)

    ok = checker.check_atoms(atoms)
    print(f'{args.pg}-symmetry:', 'Yes' if ok else 'No')

    if calc is not None:
        nspins = calc.get_number_of_spins()
        for spin in range(nspins):
            if nspins == 2:
                print('Spin', ['up', 'down'][spin])
            checker.check_calculation(calc, n1, n2, spin=spin)
105