1# coding: utf-8
2"""
Helper functions to visualize lattices, Brillouin zones, crystalline structures and scalar fields with mayavi_.
4
5WARNING: This code is still under development.
6"""
7import itertools
8import numpy as np
9
# Default keyword arguments for ``mlab.figure``, used by get_fig_mlab when a new
# figure is created and the caller did not provide the corresponding option.
DEFAULT_FIGURE_KWARGS = {
    "size": (1024, 768),
    "bgcolor": (1, 1, 1),  # white background (RGB)
    "fgcolor": (0, 0, 0),  # black foreground (RGB)
}
11
12
def get_fig_mlab(figure=None, **kwargs):  # pragma: no cover
    """
    Import mayavi's ``mlab`` module and return the figure to plot on.

    Args:
        figure: mayavi figure. If None, a new figure is created with ``mlab.figure``.
        kwargs: Keyword arguments passed to ``mlab.figure`` when a new figure is created.
            Options not given by the caller are taken from ``DEFAULT_FIGURE_KWARGS``.
            Ignored if ``figure`` is not None.

    Returns:
        ``(figure, mlab)`` tuple.

    Raises:
        ImportError: if mayavi is not installed. The message explains how to install it.
    """
    try:
        from mayavi import mlab
    except ImportError:
        # Put the installation hint in the exception instead of printing it to stdout
        # so that it shows up in tracebacks and logs. The original error is kept as
        # the implicit exception context.
        raise ImportError("mayavi is not installed. Use `conda install mayavi` or `pip install mayavi`")

    # To use the full envisage application
    #mlab.options.backend = "envisage"
    #mlab.options.backend = "test"
    #mlab.options.offscreen = True

    if figure is None:
        # Caller-provided options take precedence over the module defaults.
        for key, value in DEFAULT_FIGURE_KWARGS.items():
            kwargs.setdefault(key, value)
        figure = mlab.figure(**kwargs)

    return figure, mlab
32
33
def plot_wigner_seitz(lattice, figure=None, **kwargs):  # pragma: no cover
    """
    Adds the skeleton of the Wigner-Seitz cell of the lattice to a mayavi_ figure.

    A segment is drawn between every pair of vertices of a face that are also
    vertices of another face of ``lattice.get_wigner_seitz_cell()``, i.e. the
    edges shared by two faces. Vertices are compared exactly.

    Args:
        lattice: Reciprocal-space |Lattice| object
        figure: mayavi figure, None if a new figure should be created.
        kwargs: kwargs passed to the mayavi function ``plot3d``. Color defaults to black,
            line_width to 1 and tube_radius to None.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)

    kwargs.setdefault("color", (0, 0, 0))
    kwargs.setdefault("line_width", 1)
    kwargs.setdefault("tube_radius", None)

    bz = lattice.get_wigner_seitz_cell()

    def on_face(point, face):
        """True if ``point`` is exactly equal to one of the vertices of ``face``."""
        return any(np.all(point == vertex) for vertex in face)

    for iface in range(len(bz)):
        for line in itertools.combinations(bz[iface], 2):
            # Only faces after iface are checked so that an edge shared by
            # faces (i, j) is drawn once, for i < j.
            for jface in range(iface + 1, len(bz)):
                if on_face(line[0], bz[jface]) and on_face(line[1], bz[jface]):
                    mlab.plot3d(*zip(line[0], line[1]), figure=figure, **kwargs)

    return figure
72
73
def plot_unit_cell(lattice, figure=None, **kwargs):  # pragma: no cover
    """
    Adds the unit cell of the lattice to a mayavi_ figure.

    Args:
        lattice: Lattice object
        figure: mayavi figure, None if a new figure should be created.
        kwargs: kwargs passed to the mayavi function ``plot3d``. Color defaults to black,
            line_width to 1 and tube_radius to None.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)

    kwargs.setdefault("color", (0, 0, 0))
    kwargs.setdefault("line_width", 1)
    kwargs.setdefault("tube_radius", None)

    # Corners of the cell in reduced coordinates:
    # 0-3 lie in the plane with third coordinate 0, 4-7 in the plane with third coordinate 1.
    frac_corners = (
        [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 1.0, 0.0],
        [0.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [0.0, 0.0, 1.0],
    )
    v = [lattice.get_cartesian_coords(fc) for fc in frac_corners]

    # The 12 edges of the parallelepiped, each listed exactly once:
    # 4 in the bottom face, 4 in the top face and 4 connecting the two faces.
    edges = (
        (0, 1), (1, 2), (2, 3), (0, 3),
        (4, 5), (5, 6), (6, 7), (7, 4),
        (3, 4), (0, 7), (1, 6), (2, 5),
    )
    for i, j in edges:
        mlab.plot3d(*zip(v[i], v[j]), figure=figure, **kwargs)

    return figure
114
115
def plot_lattice_vectors(lattice, figure=None, **kwargs):  # pragma: no cover
    """
    Draws the three basis vectors of the lattice, as segments starting at the origin,
    on a mayavi_ figure.

    Args:
        lattice: |Lattice| object.
        figure: mayavi figure, None if a new figure should be created.
        kwargs: kwargs passed to the mayavi function ``plot3d``. Color defaults to black,
            line_width to 1 and tube_radius to None.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)

    for option, default in (("color", (0, 0, 0)), ("line_width", 1), ("tube_radius", None)):
        kwargs.setdefault(option, default)

    origin = lattice.get_cartesian_coords([0.0, 0.0, 0.0])
    # One segment per basis vector, from the origin to the tip at reduced coords e_i.
    for reduced_tip in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
        tip = lattice.get_cartesian_coords(reduced_tip)
        mlab.plot3d(*zip(origin, tip), figure=figure, **kwargs)

    return figure
146
147
def plot_structure(structure, frac_coords=False, to_unit_cell=False, style="points+labels",
                   unit_cell_color=(0, 0, 0), color_scheme="VESTA", figure=None, show=False, **kwargs):  # pragma: no cover
    """
    Plot structure with mayavi.

    Args:
        structure: |Structure| object
        frac_coords: True to place the sites at their fractional coordinates instead of
            their cartesian coordinates. Note that the unit cell is always drawn in
            cartesian coordinates.
        to_unit_cell: True if sites should be wrapped into the first unit cell.
        style: "points+labels" to show atoms sites with labels. The string is searched
            for the substrings "points" and "labels".
        unit_cell_color: color of the unit cell edges (RGB tuple).
        color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA")
        figure: mayavi figure, None if a new figure should be created.
        show: True to call ``mlab.show()`` before returning.
        kwargs: keyword arguments passed to ``mlab.points3d``.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)

    #if not frac_coords:
    plot_unit_cell(structure.lattice, color=unit_cell_color, figure=figure)
    from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
    from pymatgen.vis.structure_vtk import EL_COLORS

    for site in structure:
        symbol = site.specie.symbol
        # EL_COLORS entries are scaled by 255 (presumably 0-255 RGB) to get floats in [0, 1].
        # Float division avoids truncation to 0 under integer division.
        color = tuple(c / 255.0 for c in EL_COLORS[color_scheme][symbol])
        radius = CovalentRadius.radius[symbol]
        if to_unit_cell and hasattr(site, "to_unit_cell"):
            # Depending on the pymatgen version, to_unit_cell is either a property that
            # returns the wrapped site or a method that must be called.
            wrapped = site.to_unit_cell
            site = wrapped() if callable(wrapped) else wrapped
        x, y, z = site.frac_coords if frac_coords else site.coords

        if "points" in style:
            mlab.points3d(x, y, z, figure=figure, scale_factor=radius,
                          resolution=20, color=color, scale_mode='none', **kwargs)
        if "labels" in style:
            mlab.text3d(x, y, z, symbol, figure=figure, color=(0, 0, 0), scale=0.2)

    if show:
        mlab.show()
    return figure
187
188
def plot_labels(labels, lattice=None, coords_are_cartesian=False, figure=None, **kwargs):  # pragma: no cover
    """
    Adds labels to a mayavi_ figure.

    Args:
        labels: dict containing the label as a key and the coordinates as value.
        lattice: |Lattice| object used to convert the coordinates from fractional
            (reduced) to cartesian coordinates.
        coords_are_cartesian: Set to True if the coordinates are given in cartesian
            coordinates. Defaults to False, in which case ``lattice`` is required.
        figure: mayavi figure, None if a new figure should be created.
        kwargs: kwargs passed to the mayavi function ``text3d``. ``scale`` defaults to 0.1.

    Returns: mayavi figure

    Raises:
        ValueError: if ``labels`` is not empty, ``coords_are_cartesian`` is False
            and ``lattice`` is None.
    """
    figure, mlab = get_fig_mlab(figure=figure)

    kwargs.setdefault("scale", 0.1)

    if labels and not coords_are_cartesian and lattice is None:
        raise ValueError("coords_are_cartesian False requires the lattice")

    # Constant shift applied to every coordinate of the label position
    # (presumably to avoid overlapping with the labeled point).
    off = 0.01
    for key, coords in labels.items():
        label = key
        # NOTE(review): "$...$" is matplotlib mathtext syntax; mayavi's text3d
        # presumably renders it literally rather than as math — confirm.
        if key.startswith("\\") or "_" in key:
            label = "$" + key + "$"
        if coords_are_cartesian:
            coords = np.array(coords)
        else:
            coords = lattice.get_cartesian_coords(coords)
        x, y, z = coords + off
        mlab.text3d(x, y, z, label, figure=figure, **kwargs)

    return figure
230
231
class MayaviFieldAnimator(object):  # pragma: no cover
    """
    Animate, with mayavi_, the scalar field stored in a sequence of files.

    Each file is opened with ``abilab.abiopen`` and the first component of
    ``nc.field.datar`` is displayed with three image-plane widgets
    (one per axis). The animation cycles through the files indefinitely.
    """

    def __init__(self, filepaths):
        """
        Args:
            filepaths: Iterable with the paths of the files, one per animation frame.
        """
        self.filepaths = list(filepaths)
        self.num_files = len(self.filepaths)

    def volume_animate(self):
        """
        Show the field of the first file, then update the plot every second with
        the field read from the next file, wrapping around at the end of the list.

        Raises:
            ValueError: if no file path was given.
        """
        if not self.filepaths:
            raise ValueError("MayaviFieldAnimator requires at least one file path")

        from abipy import abilab

        with abilab.abiopen(self.filepaths[0]) as nc:
            # datar is presumably a [nspden, nx, ny, nz] array: only the first component is shown.
            first_scalars = nc.field.datar[0]

        _, mlab = get_fig_mlab(figure=None)
        source = mlab.pipeline.scalar_field(first_scalars)
        for orientation in ("x_axes", "y_axes", "z_axes"):
            mlab.pipeline.image_plane_widget(source, plane_orientation=orientation, slice_index=0)

        @mlab.show
        @mlab.animate(delay=1000, ui=True)
        def anim():
            """Animate: read the next file and update the scalars of the source."""
            # The first file is already displayed, so start from the second one.
            # The modulo avoids an IndexError when a single file is given.
            t = 1 % self.num_files
            while True:
                with abilab.abiopen(self.filepaths[t]) as nc:
                    print("Animation step", t, "from file:", self.filepaths[t])
                    scalars = nc.field.datar[0]

                source.mlab_source.scalars = scalars

                t = (t + 1) % self.num_files
                yield

        anim()
309