1import optparse
2
3import numpy as np
4
5from ase.data import covalent_radii
6from ase.io.cube import read_cube_data
7from ase.data.colors import cpk_colors
8from ase.calculators.calculator import get_calculator_class
9
10
def plot(atoms, data, contours):
    """Plot atoms, unit cell and iso-surfaces using Mayavi.

    Parameters:

    atoms: Atoms object
        Positions, atomic numbers and unit cell.
    data: 3-d ndarray of float
        Data for iso-surfaces, sampled on a grid whose axes run along the
        three unit-cell vectors.
    contours: list of float
        Contour values.
    """

    # Delay slow imports:
    from mayavi import mlab

    mlab.figure(1, bgcolor=(1, 1, 1))  # make a white figure

    # Plot the atoms as spheres, sized by covalent radius and CPK-coloured:
    for pos, Z in zip(atoms.positions, atoms.numbers):
        mlab.points3d(*pos,
                      scale_factor=covalent_radii[Z],
                      resolution=20,
                      color=tuple(cpk_colors[Z]))

    # Work on a plain float ndarray so that iteration, indexing and the
    # division further down behave the same whether atoms.cell is an
    # ndarray or an ase Cell object.
    A = np.asarray(atoms.cell, dtype=float)

    # Draw the unit cell: for each cell vector ``a`` draw the four edges
    # parallel to it, starting at 0, A[i2], A[i3] and A[i2] + A[i3].
    origin = np.zeros(3)
    for i1, a in enumerate(A):
        i2 = (i1 + 1) % 3
        i3 = (i1 + 2) % 3
        for b in [origin, A[i2]]:
            for c in [origin, A[i3]]:
                p1 = b + c
                p2 = p1 + a
                mlab.plot3d([p1[0], p2[0]],
                            [p1[1], p2[1]],
                            [p1[2], p2[2]],
                            tube_radius=0.1)

    cp = mlab.contour3d(data, contours=contours, transparent=True,
                        opacity=0.5, colormap='hot')
    # Do some tvtk magic in order to allow for non-orthogonal unit cells:
    # the iso-surface points come out in grid-index coordinates and are
    # mapped into the cell by hand.
    polydata = cp.actor.actors[0].mapper.input
    # NOTE(review): the -1 presumably compensates for Mayavi's default grid
    # indices starting at 1 rather than 0 -- confirm against Mayavi docs.
    pts = np.array(polydata.points) - 1
    # Row i of the matrix below is cell vector i divided by the number of
    # grid points along axis i, i.e. the grid step along that axis.
    polydata.points = np.dot(pts, A / np.array(data.shape)[:, np.newaxis])

    # Apparently we need this to redraw the figure, maybe it can be done in
    # another way?
    mlab.view(azimuth=155, elevation=70, distance='auto')
    # Show the 3d plot:
    mlab.show()
63
64
# Help text shown by the command-line parser in main().
description = ('Plot iso-surfaces from a cube-file or a wave function or an '
               'electron\n'
               'density from a calculator-restart file.')
68
69
def _parse_contours(spec, mn, mx):
    """Convert the ``--contours`` option string into a list of floats.

    Parameters:

    spec: str
        Either a plain non-negative integer such as ``"4"`` (that many
        evenly spaced contours) or a comma-separated list of explicit
        values such as ``"-0.5,0.5"`` (a trailing comma is allowed).
    mn, mx: float
        Minimum and maximum of the data; used only for the integer form.

    Returns a list of float contour values.

    Raises ValueError if the integer form asks for fewer than one contour
    or if any explicit value is not a valid float.
    """
    if spec.isdigit():
        n = int(spec)
        if n < 1:
            raise ValueError('the number of contours must be at least 1')
        # Place n contours at the centres of n equal-width bins spanning
        # [mn, mx], so none of them sits exactly on the extremes.
        d = (mx - mn) / n
        return np.linspace(mn + d / 2, mx - d / 2, n).tolist()
    return [float(x) for x in spec.rstrip(',').split(',')]


def main(args=None):
    """Command-line entry point: read data and plot its iso-surfaces.

    ``args`` is the argument list handed to optparse; ``None`` means the
    process command line.  A ``.cube`` file is read directly.  Any other
    file is opened as a restart file of the calculator named with ``-C``.
    From that calculator the script reads either the pseudo density, the
    electrostatic potential (``-e``), or the magnitude of a pseudo wave
    function (``-n``/``-s``).
    """
    parser = optparse.OptionParser(usage='%prog [options] filename',
                                   description=description)
    add = parser.add_option
    add('-n', '--band-index', type=int, metavar='INDEX',
        help='Band index counting from zero.')
    add('-s', '--spin-index', type=int, metavar='SPIN',
        help='Spin index: zero or one.')
    add('-e', '--electrostatic-potential', action='store_true',
        help='Plot the electrostatic potential.')
    add('-c', '--contours', default='4',
        help='Use "-c 3" for 3 contours or "-c -0.5,0.5" for specific ' +
        'values.  Default is four contours.')
    add('-r', '--repeat', help='Example: "-r 2,2,2".')
    add('-C', '--calculator-name', metavar='NAME', help='Name of calculator.')

    opts, args = parser.parse_args(args)
    if len(args) != 1:
        parser.error('Incorrect number of arguments')

    arg = args[0]
    if arg.endswith('.cube'):
        data, atoms = read_cube_data(arg)
    else:
        # Without a name there is no calculator class to open the restart
        # file with; fail with a usage message instead of a traceback.
        if opts.calculator_name is None:
            parser.error('A calculator name (-C NAME) is required for '
                         'non-cube files')
        calc = get_calculator_class(opts.calculator_name)(arg, txt=None)
        atoms = calc.get_atoms()
        if opts.band_index is None:
            if opts.electrostatic_potential:
                data = calc.get_electrostatic_potential()
            else:
                data = calc.get_pseudo_density(opts.spin_index)
        else:
            data = calc.get_pseudo_wave_function(opts.band_index,
                                                 opts.spin_index or 0)
            # iso-surfaces need real values; iscomplexobj also catches
            # complex64, which a ``dtype == complex`` test would miss.
            if np.iscomplexobj(data):
                data = abs(data)

    mn = data.min()
    mx = data.max()
    print('Min: %16.6f' % mn)
    print('Max: %16.6f' % mx)

    try:
        contours = _parse_contours(opts.contours, mn, mx)
    except ValueError as err:
        parser.error('Bad --contours value %r: %s' % (opts.contours, err))

    if len(contours) == 1:
        print('1 contour:', contours[0])
    else:
        print('%d contours: %.6f, ..., %.6f' %
              (len(contours), contours[0], contours[-1]))

    if opts.repeat:
        # Repeat both the grid data and the atoms so they stay consistent.
        repeat = [int(r) for r in opts.repeat.split(',')]
        data = np.tile(data, repeat)
        atoms *= repeat

    plot(atoms, data, contours)
131
132
if __name__ == '__main__':
    # Run as a script; passing None lets optparse read the command line.
    main(None)
135