# coding: utf-8
"""
mayavi_ toolkit.

WARNING: This code is still under development.
"""
import itertools
import numpy as np

DEFAULT_FIGURE_KWARGS = dict(size=(1024, 768), bgcolor=(1, 1, 1), fgcolor=(0, 0, 0))

# Fractional coordinates of the eight corners of the unit cell.
# The ordering is the one UNIT_CELL_EDGES indexes into.
UNIT_CELL_VERTICES = (
    (0.0, 0.0, 0.0),
    (1.0, 0.0, 0.0),
    (1.0, 1.0, 0.0),
    (0.0, 1.0, 0.0),
    (0.0, 1.0, 1.0),
    (1.0, 1.0, 1.0),
    (1.0, 0.0, 1.0),
    (0.0, 0.0, 1.0),
)

# The 12 edges of the unit cell as pairs of indices into UNIT_CELL_VERTICES.
# Each pair joins two corners that differ in exactly one fractional coordinate.
UNIT_CELL_EDGES = (
    (0, 1), (1, 2), (2, 3), (0, 3),
    (3, 4), (4, 5), (5, 6), (6, 7), (7, 4),
    (0, 7), (1, 6), (2, 5),
)


def _set_line_defaults(kwargs):
    """Fill in-place the default ``plot3d`` options shared by the line plotters."""
    kwargs.setdefault("color", (0, 0, 0))
    kwargs.setdefault("line_width", 1)
    kwargs.setdefault("tube_radius", None)
    return kwargs


def _face_has_vertex(face, vertex):
    """Return True if ``vertex`` coincides exactly with one of the vertices of ``face``."""
    return any(np.all(vertex == other) for other in face)


def get_fig_mlab(figure=None, **kwargs):  # pragma: no cover
    """
    Import ``mayavi.mlab`` lazily and return a figure to plot on.

    Args:
        figure: mayavi figure. If None, a new figure is created.
        kwargs: kwargs passed to ``mlab.figure`` when a new figure is created.
            Missing entries are taken from ``DEFAULT_FIGURE_KWARGS``.
            Ignored if ``figure`` is given.

    Returns: (figure, mlab) tuple.

    Raises:
        ImportError: if mayavi cannot be imported (an installation hint is printed first).
    """
    try:
        from mayavi import mlab
    except ImportError:
        print("mayavi is not installed. Use `conda install mayavi` or `pip install mayavi`")
        raise

    # NOTE: the full envisage application can be selected here with
    # mlab.options.backend = "envisage" (or "test"), and offscreen rendering
    # with mlab.options.offscreen = True.

    if figure is None:
        # Add defaults without overriding what the caller specified.
        for key, value in DEFAULT_FIGURE_KWARGS.items():
            kwargs.setdefault(key, value)
        figure = mlab.figure(**kwargs)

    return figure, mlab


def plot_wigner_seitz(lattice, figure=None, **kwargs):  # pragma: no cover
    """
    Adds the skeleton of the Wigner-Seitz cell of the lattice to a mayavi_ figure

    Args:
        lattice: Reciprocal-space |Lattice| object
        figure: mayavi figure, None to create a new figure.
        kwargs: kwargs passed to the mayavi function ``plot3d``. Color defaults to black,
            line_width to 1 and tube_radius to None.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)
    _set_line_defaults(kwargs)

    bz = lattice.get_wigner_seitz_cell()
    for iface, face in enumerate(bz):
        for start, end in itertools.combinations(face, 2):
            # A pair of vertices is drawn once for every later face that contains both:
            # such a pair is an edge shared by the two faces.
            for other_face in bz[iface + 1:]:
                if _face_has_vertex(other_face, start) and _face_has_vertex(other_face, end):
                    mlab.plot3d(*zip(start, end), figure=figure, **kwargs)

    return figure


def plot_unit_cell(lattice, figure=None, **kwargs):  # pragma: no cover
    """
    Adds the unit cell of the lattice to a mayavi_ figure.

    Args:
        lattice: Lattice object
        figure: mayavi figure, None to create a new figure.
        kwargs: kwargs passed to the mayavi function ``plot3d``. Color defaults to black,
            line_width to 1 and tube_radius to None.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)
    _set_line_defaults(kwargs)

    corners = [lattice.get_cartesian_coords(list(frac)) for frac in UNIT_CELL_VERTICES]

    # Each of the 12 edges is drawn exactly once.
    for i, j in UNIT_CELL_EDGES:
        mlab.plot3d(*zip(corners[i], corners[j]), figure=figure, **kwargs)

    return figure


def plot_lattice_vectors(lattice, figure=None, **kwargs):  # pragma: no cover
    """
    Adds the basis vectors of the lattice provided to a mayavi_ figure.

    Args:
        lattice: |Lattice| object.
        figure: mayavi figure, None if a new figure should be created.
        kwargs: kwargs passed to the mayavi function ``plot3d``. Color defaults to black,
            line_width to 1 and tube_radius to None.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)
    _set_line_defaults(kwargs)

    origin = lattice.get_cartesian_coords([0.0, 0.0, 0.0])
    for frac in ([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]):
        tip = lattice.get_cartesian_coords(frac)
        mlab.plot3d(*zip(origin, tip), figure=figure, **kwargs)

    return figure


def plot_structure(structure, frac_coords=False, to_unit_cell=False, style="points+labels",
                   unit_cell_color=(0, 0, 0), color_scheme="VESTA", figure=None, show=False,
                   **kwargs):  # pragma: no cover
    """
    Plot structure with mayavi.

    Args:
        structure: |Structure| object
        frac_coords: True to place the sites at their fractional coordinates,
            False to use Cartesian coordinates. The unit cell is always drawn
            from the Cartesian lattice vectors.
        to_unit_cell: True if sites should be wrapped into the first unit cell.
        style: "points+labels" to show atoms sites with labels. A sphere is drawn if the
            string contains "points", the chemical symbol if it contains "labels".
        unit_cell_color: Color used to draw the unit cell.
        color_scheme: color scheme for atom types. Allowed values in ("Jmol", "VESTA")
        figure: mayavi figure, None to create a new figure.
        show: True to call ``mlab.show()`` before returning.
        kwargs: kwargs passed to the mayavi function ``points3d``.

    Returns: mayavi figure
    """
    figure, mlab = get_fig_mlab(figure=figure)

    plot_unit_cell(structure.lattice, color=unit_cell_color, figure=figure)
    from pymatgen.analysis.molecule_structure_comparator import CovalentRadius
    from pymatgen.vis.structure_vtk import EL_COLORS

    draw_points = "points" in style
    draw_labels = "labels" in style

    for site in structure:
        symbol = site.specie.symbol
        color = tuple(i / 255.0 for i in EL_COLORS[color_scheme][symbol])
        radius = CovalentRadius.radius[symbol]
        if to_unit_cell and hasattr(site, "to_unit_cell"):
            # `to_unit_cell` is a property in old pymatgen versions and a method in
            # newer ones: call it when needed so that we always end up with a site.
            wrapped = site.to_unit_cell
            site = wrapped() if callable(wrapped) else wrapped
        x, y, z = site.frac_coords if frac_coords else site.coords

        if draw_points:
            mlab.points3d(x, y, z, figure=figure, scale_factor=radius,
                          resolution=20, color=color, scale_mode='none', **kwargs)
        if draw_labels:
            mlab.text3d(x, y, z, symbol, figure=figure, color=(0, 0, 0), scale=0.2)

    if show:
        mlab.show()
    return figure


def plot_labels(labels, lattice=None, coords_are_cartesian=False, figure=None, **kwargs):  # pragma: no cover
    """
    Adds labels to a mayavi_ figure.

    Args:
        labels: dict containing the label as a key and the coordinates as value.
        lattice: |Lattice| object used to convert from reciprocal to cartesian coordinates
        coords_are_cartesian: Set to True if you are providing
            coordinates in cartesian coordinates. Defaults to False.
            Requires lattice if False.
        figure: mayavi figure, None to create a new figure.
        kwargs: kwargs passed to the mayavi function `text3d`. scale defaults to 0.1.

    Returns: mayavi figure

    Raises:
        ValueError: if a label must be converted to cartesian coordinates
            and ``lattice`` is None.
    """
    figure, mlab = get_fig_mlab(figure=figure)
    kwargs.setdefault("scale", 0.1)

    # Small shift applied to all three coordinates so that the text
    # does not sit exactly on the labelled point.
    off = 0.01

    for key, coords in labels.items():
        # Labels that look like LaTeX (leading backslash or a subscript) are wrapped in $...$
        label = "$" + key + "$" if key.startswith("\\") or "_" in key else key

        if coords_are_cartesian:
            coords = np.array(coords)
        else:
            if lattice is None:
                raise ValueError("coords_are_cartesian False requires the lattice")
            coords = lattice.get_cartesian_coords(coords)

        x, y, z = coords + off
        mlab.text3d(x, y, z, label, figure=figure, **kwargs)

    return figure


class MayaviFieldAnimator(object):  # pragma: no cover
    """
    Animate the scalar field stored in a list of files.

    The files are opened with ``abilab.abiopen`` and must provide ``field.datar``.
    """

    def __init__(self, filepaths):
        """
        Args:
            filepaths: List of paths to the files with the field. One file per frame.
        """
        self.filepaths = filepaths
        self.num_files = len(filepaths)

    def volume_animate(self):
        """
        Show three orthogonal image planes of the first component of the field
        and cycle over the files, one frame per second, until the window is closed.
        """
        from abipy import abilab

        with abilab.abiopen(self.filepaths[0]) as nc:
            # datar is indexed as [nspden, nx, ny, nz]: only the first component is shown.
            s = nc.field.datar[0]
            print(s.dtype, s.shape)

        _, mlab = get_fig_mlab(figure=None)
        source = mlab.pipeline.scalar_field(s)
        data_min, data_max = s.min(), s.max()
        print(data_min, data_max)

        for orientation in ("x_axes", "y_axes", "z_axes"):
            mlab.pipeline.image_plane_widget(source, plane_orientation=orientation, slice_index=0)

        @mlab.show
        @mlab.animate(delay=1000, ui=True)
        def anim():
            """Animate."""
            # Frame 0 is already on screen, so start from the second file.
            # The modulo keeps the index valid when there is only one file.
            t = 1 % self.num_files
            while True:
                with abilab.abiopen(self.filepaths[t]) as nc:
                    print("Animation step", t, "from file:", self.filepaths[t])
                    scalars = nc.field.datar[0]

                source.mlab_source.scalars = scalars

                t = (t + 1) % self.num_files
                yield

        anim()