1"""
2vtkImageExportToArray - a NumPy front-end to vtkImageExport
3
This class converts a VTK image to a numpy array.  The output
array always has 3 dimensions, ordered (z, y, x), or 4 dimensions,
ordered (z, y, x, component), if the image has multiple scalar
components.
7
To use this class, you must have numpy installed (https://numpy.org).
9
10Methods
11
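  SetInputConnection(vtkAlgorithmOutput) -- connect to VTK image pipeline
  SetInputData(vtkImageData) -- set a vtkImageData to export
  GetInput() -- return the input image data
  GetArray() -- execute pipeline and return a numpy array
  SetConvertUnsignedShortToInt(bool), ConvertUnsignedShortToIntOn(),
  ConvertUnsignedShortToIntOff() -- return unsigned short images as int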
15
16Methods from vtkImageExport
17
18  GetDataExtent()
19  GetDataSpacing()
20  GetDataOrigin()
21"""
22
23import numpy
24import numpy.core.umath as umath
25
26from vtk import vtkImageExport
27from vtk import vtkStreamingDemandDrivenPipeline
28from vtk import VTK_SIGNED_CHAR
29from vtk import VTK_UNSIGNED_CHAR
30from vtk import VTK_SHORT
31from vtk import VTK_UNSIGNED_SHORT
32from vtk import VTK_INT
33from vtk import VTK_UNSIGNED_INT
34from vtk import VTK_LONG
35from vtk import VTK_UNSIGNED_LONG
36from vtk import VTK_FLOAT
37from vtk import VTK_DOUBLE
38
39
class vtkImageExportToArray:
    """Copy the scalars of a VTK image into a numpy array.

    Wraps a vtkImageExport instance.  GetArray() updates the pipeline and
    returns a freshly allocated array whose axes are ordered (z, y, x), with
    a trailing component axis when the image has more than one scalar
    component.
    """

    def __init__(self):
        self.__export = vtkImageExport()
        self.__ConvertUnsignedShortToInt = False

    # Map VTK scalar type constants to numpy type codes.  VTK_LONG and
    # VTK_UNSIGNED_LONG hold C ``long`` values, which numpy's 'l'/'L' codes
    # match on every platform (32 or 64 bits, depending on the platform).
    __typeDict = { VTK_SIGNED_CHAR:'b',
                   VTK_UNSIGNED_CHAR:'B',
                   VTK_SHORT:'h',
                   VTK_UNSIGNED_SHORT:'H',
                   VTK_INT:'i',
                   VTK_UNSIGNED_INT:'I',
                   VTK_LONG:'l',
                   VTK_UNSIGNED_LONG:'L',
                   VTK_FLOAT:'f',
                   VTK_DOUBLE:'d'}

    # Size in bytes of each supported scalar type.  Not used by this class
    # itself; kept in step with __typeDict.
    __sizeDict = { VTK_SIGNED_CHAR:1,
                   VTK_UNSIGNED_CHAR:1,
                   VTK_SHORT:2,
                   VTK_UNSIGNED_SHORT:2,
                   VTK_INT:4,
                   VTK_UNSIGNED_INT:4,
                   VTK_LONG:numpy.dtype('l').itemsize,
                   VTK_UNSIGNED_LONG:numpy.dtype('L').itemsize,
                   VTK_FLOAT:4,
                   VTK_DOUBLE:8 }

    # convert unsigned shorts to ints, to avoid sign problems
    def SetConvertUnsignedShortToInt(self,yesno):
        """Enable (true) or disable (false) widening of unsigned short output."""
        self.__ConvertUnsignedShortToInt = yesno

    def GetConvertUnsignedShortToInt(self):
        """Return the current unsigned-short-to-int conversion flag."""
        return self.__ConvertUnsignedShortToInt

    def ConvertUnsignedShortToIntOn(self):
        """Return unsigned short images from GetArray() as an 'i' array."""
        self.__ConvertUnsignedShortToInt = True

    def ConvertUnsignedShortToIntOff(self):
        """Return unsigned short images from GetArray() unchanged ('H')."""
        self.__ConvertUnsignedShortToInt = False

    # set the input
    def SetInputConnection(self,input):
        """Connect the exporter to an upstream algorithm output port."""
        return self.__export.SetInputConnection(input)

    def SetInputData(self,input):
        """Set a vtkImageData object directly as the exporter's input."""
        return self.__export.SetInputData(input)

    def GetInput(self):
        """Return the exporter's input image data."""
        return self.__export.GetInput()

    def GetArray(self):
        """Update the pipeline and return the image scalars as a numpy array.

        Returns:
            An array of shape (nz, ny, nx), or (nz, ny, nx, ncomp) when the
            image has more than one scalar component.  Its dtype matches the
            image scalar type, except that unsigned short images are widened
            to 'i' (C int) when ConvertUnsignedShortToInt is on.

        Raises:
            KeyError: if the image scalar type has no entry in the type map.
        """
        self.__export.Update()
        image = self.__export.GetInput()
        extent = image.GetExtent()
        scalarType = image.GetScalarType()
        numComponents = image.GetNumberOfScalarComponents()

        # numpy arrays are row-major, so VTK's fastest-varying axis (x) is
        # placed last: (z, y, x).
        dim = (extent[5] - extent[4] + 1,
               extent[3] - extent[2] + 1,
               extent[1] - extent[0] + 1)
        if numComponents > 1:
            dim = dim + (numComponents,)

        typecode = self.__typeDict.get(scalarType)
        if typecode is None:
            raise KeyError("unsupported VTK scalar type %d" % scalarType)

        imArray = numpy.zeros(dim, typecode)
        # vtkImageExport writes the image scalars into imArray's buffer.
        self.__export.Export(imArray)

        # convert unsigned short to int to avoid sign issues
        if scalarType == VTK_UNSIGNED_SHORT and self.__ConvertUnsignedShortToInt:
            # Every uint16 value fits in a C int, so a plain widening cast
            # preserves all values; no masking is needed.
            imArray = imArray.astype('i')

        return imArray

    def GetDataExtent(self):
        """Return vtkImageExport's data extent."""
        return self.__export.GetDataExtent()

    def GetDataSpacing(self):
        """Return vtkImageExport's data spacing."""
        return self.__export.GetDataSpacing()

    def GetDataOrigin(self):
        """Return vtkImageExport's data origin."""
        return self.__export.GetDataOrigin()
117