1import numpy as np
2
3
# Data mode used by write_vti and write_vtu below: when True, the VTK XML
# writers store data in appended mode with encoding of the appended data
# turned off; when False (the default), data is written as ASCII.
fast = False
5
6
def write_vti(filename, atoms, data=None):
    """Write a 3-D grid of values on an orthogonal cell to a VTK XML
    image data (VTI) file.

    Parameters
    ----------
    filename : str, os.PathLike or file object
        Destination. For an object that is neither a string nor a path,
        its ``name`` attribute is used as the file name.
    atoms : Atoms or list of Atoms
        Supplies the unit cell via ``get_cell()``. A list is accepted only
        if it holds exactly one configuration.
    data : array_like
        3-D array of grid values. Complex input is written as its
        magnitude.

    Raises
    ------
    ValueError
        If the list of configurations is empty or has more than one entry,
        if ``data`` is missing or not 3-D, or if the cell is not orthogonal.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTI file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTI file!')
        atoms = atoms[0]

    if data is None:
        raise ValueError('VTK XML Image Data (VTI) format requires data!')

    data = np.asarray(data)
    if np.iscomplexobj(data):
        # Scalars are stored as real doubles; keep the magnitude.
        data = np.abs(data)
    if data.ndim != 3:
        raise ValueError('VTI data must be a 3-D array, got shape %s'
                         % (data.shape,))

    # np.asarray accepts both a plain ndarray and an array-like cell object.
    cell = np.asarray(atoms.get_cell())
    if not np.all(cell == np.diag(np.diag(cell))):
        raise ValueError('Unit cell must be orthogonal')

    # Import the optional dependency only once the input is known to be valid.
    from vtk import (VTK_MAJOR_VERSION, vtkStructuredPoints,
                     vtkXMLImageDataWriter)
    from vtk.util.numpy_support import numpy_to_vtk

    lengths = cell.diagonal()

    # Create a VTK grid of structured points spanning the cell.
    spts = vtkStructuredPoints()
    if VTK_MAJOR_VERSION <= 5:
        # SetWholeBoundingBox no longer exists on data objects in VTK >= 6.
        bbox = np.array(list(zip(np.zeros(3), lengths))).ravel()
        spts.SetWholeBoundingBox(bbox)
    spts.SetDimensions(data.shape)
    spts.SetSpacing(lengths / data.shape)

    # swapaxes(0, 2) followed by flattening makes the first array index vary
    # fastest, which is the point ordering of VTK image data. Converting to
    # float64 keeps the resulting VTK array a double array; deep=1 copies it.
    flat = np.ascontiguousarray(data.swapaxes(0, 2).ravel(), dtype=np.float64)
    scalars = numpy_to_vtk(flat, deep=1)
    scalars.SetName('scalars')
    spts.GetPointData().SetScalars(scalars)

    # Save the ImageData dataset to a VTK XML file.
    w = vtkXMLImageDataWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.SetDataModeToAscii()

    if isinstance(filename, (str, os.PathLike)):
        w.SetFileName(os.fspath(filename))
    else:
        w.SetFileName(filename.name)

    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(spts)
    else:
        w.SetInputData(spts)
    w.Write()
85
86
def write_vtu(filename, atoms, data=None):
    """Write atoms to a VTK XML unstructured grid (VTU) file.

    Each atom becomes one VTK point, carrying the point-data arrays
    ``"atomic numbers"``, ``"tags"`` and ``"radii"`` (covalent radii taken
    from ``ase.data.covalent_radii``).

    Parameters
    ----------
    filename : str, os.PathLike or file object
        Destination. For an object that is neither a string nor a path,
        its ``name`` attribute is used as the file name.
    atoms : Atoms or list of Atoms
        Configuration to write. A list is accepted only if it holds exactly
        one configuration.
    data : optional
        Not used by this writer.

    Raises
    ------
    ValueError
        If the list of configurations is empty or has more than one entry.
    """
    import os

    if isinstance(atoms, list):
        if len(atoms) > 1:
            raise ValueError('Can only write one configuration to a VTU file!')
        if not atoms:
            raise ValueError('No configuration to write to a VTU file!')
        atoms = atoms[0]

    # Import the optional dependencies only once the input is known to be valid.
    from vtk import (VTK_MAJOR_VERSION, vtkUnstructuredGrid, vtkPoints,
                     vtkXMLUnstructuredGridWriter)
    from vtk.util.numpy_support import numpy_to_vtk
    from ase.data import covalent_radii

    # Create a VTK unstructured grid holding the atoms as points.
    ugd = vtkUnstructuredGrid()

    positions = atoms.get_positions()
    points = vtkPoints()
    # Choose the data type before allocating: changing it afterwards replaces
    # the underlying array and discards the allocation.
    points.SetDataTypeToDouble()
    points.SetNumberOfPoints(len(positions))
    for i, pos in enumerate(positions):
        points.InsertPoint(i, *pos)
    ugd.SetPoints(points)

    # Per-atom point-data arrays; deep=1 copies the numpy data into VTK.
    point_data = ugd.GetPointData()
    for name, values in (('atomic numbers', atoms.get_atomic_numbers()),
                         ('tags', atoms.get_tags()),
                         ('radii', covalent_radii[atoms.numbers])):
        array = numpy_to_vtk(values, deep=1)
        array.SetName(name)
        point_data.AddArray(array)

    # Save the UnstructuredGrid dataset to a VTK XML file.
    w = vtkXMLUnstructuredGridWriter()
    if fast:
        w.SetDataModeToAppend()
        w.EncodeAppendedDataOff()
    else:
        w.GetCompressor().SetCompressionLevel(0)
        w.SetDataModeToAscii()

    # A pathlib.Path's ``.name`` is only the basename, so path-like objects
    # are converted with os.fspath; other objects are treated as open files.
    if isinstance(filename, (str, os.PathLike)):
        w.SetFileName(os.fspath(filename))
    else:
        w.SetFileName(filename.name)

    # SetInput was replaced by SetInputData in VTK 6.
    if VTK_MAJOR_VERSION <= 5:
        w.SetInput(ugd)
    else:
        w.SetInputData(ugd)
    w.Write()
143