1"""
2vtkImageImportFromArray: a NumPy front-end to vtkImageImport
3
4Load a python array into a vtk image.
To use this class, you must have NumPy installed (https://numpy.org/)
6
7Methods:
8
9  SetArray()  -- set the numpy array to load
10  Update()    -- generate the output
11  GetOutput() -- get the image as vtkImageData
12  GetOutputPort() -- connect to VTK pipeline
13
14Methods from vtkImageImport:
15(if you don't set these, sensible defaults will be used)
16
17  SetDataExtent()
18  SetDataSpacing()
19  SetDataOrigin()
20"""
21
22from vtk import vtkImageImport
23from vtk import VTK_SIGNED_CHAR
24from vtk import VTK_UNSIGNED_CHAR
25from vtk import VTK_SHORT
26from vtk import VTK_UNSIGNED_SHORT
27from vtk import VTK_INT
28from vtk import VTK_UNSIGNED_INT
29from vtk import VTK_LONG
30from vtk import VTK_UNSIGNED_LONG
31from vtk import VTK_FLOAT
32from vtk import VTK_DOUBLE
33
class vtkImageImportFromArray:
    """Import a NumPy array into the VTK pipeline as image data.

    Thin wrapper around a private vtkImageImport instance.  SetArray()
    hands the array's bytes to the importer and derives the scalar type,
    the number of scalar components and the extents from the array's
    dtype and shape.  All other methods forward to the importer.
    """

    def __init__(self):
        self.__import = vtkImageImport()
        self.__ConvertIntToUnsignedShort = False
        self.__Array = None

    # NumPy typecode (dtype.char) -> VTK scalar type.
    # 'l'/'L' are C long / unsigned long, exactly like VTK_LONG and
    # VTK_UNSIGNED_LONG, so they map one-to-one whatever the platform
    # width is (e.g. int64 on 64-bit Linux, int32 on Windows).
    # Complex types are imported as interleaved (real, imaginary) pairs
    # of the matching floating-point type.
    __typeDict = {'b':VTK_SIGNED_CHAR,     # int8
                  'B':VTK_UNSIGNED_CHAR,   # uint8
                  'h':VTK_SHORT,           # int16
                  'H':VTK_UNSIGNED_SHORT,  # uint16
                  'i':VTK_INT,             # int32
                  'I':VTK_UNSIGNED_INT,    # uint32
                  'l':VTK_LONG,            # C long
                  'L':VTK_UNSIGNED_LONG,   # C unsigned long
                  'f':VTK_FLOAT,           # float32
                  'd':VTK_DOUBLE,          # float64
                  'F':VTK_FLOAT,           # complex64  -> 2 x float32
                  'D':VTK_DOUBLE,          # complex128 -> 2 x float64
                  }

    def SetConvertIntToUnsignedShort(self,yesno):
        """Enable/disable casting 32-bit signed int arrays to unsigned short."""
        self.__ConvertIntToUnsignedShort = yesno

    def GetConvertIntToUnsignedShort(self):
        """Return whether 32-bit signed int arrays are cast to unsigned short."""
        return self.__ConvertIntToUnsignedShort

    def ConvertIntToUnsignedShortOn(self):
        """Turn the int32 -> unsigned short conversion on."""
        self.__ConvertIntToUnsignedShort = True

    def ConvertIntToUnsignedShortOff(self):
        """Turn the int32 -> unsigned short conversion off."""
        self.__ConvertIntToUnsignedShort = False

    def Update(self):
        """Call Update() on the underlying vtkImageImport."""
        self.__import.Update()

    def GetOutputPort(self):
        """Return the importer's output port, for pipeline connections."""
        return self.__import.GetOutputPort()

    def GetOutput(self):
        """Return the importer's output (the image as vtkImageData)."""
        return self.__import.GetOutput()

    def SetArray(self,imArray):
        """Load a NumPy array into the importer.

        The shape is interpreted as (z, y, x); arrays with fewer than
        three dimensions are padded with leading axes of length 1.  A
        4-D array is read as (z, y, x, components).  Complex arrays
        double the number of components (real and imaginary parts).

        The extents keep the lower bounds currently set on the importer
        and are resized to the array's dimensions.

        Raises:
            ValueError: if the array has more than 4 dimensions.
            KeyError: if the array's dtype has no VTK equivalent in the
                type table.
        """
        dim = imArray.shape
        numComponents = 1
        numDims = len(dim)
        if numDims > 4:
            # Previously such arrays were imported with extents taken from
            # the first three axes only, which did not match the byte count.
            raise ValueError("expected an array with at most 4 dimensions, "
                             "got %d" % numDims)
        if numDims == 0:
            dim = (1, 1, 1)
        elif numDims == 1:
            dim = (1, 1, dim[0])
        elif numDims == 2:
            dim = (1, dim[0], dim[1])
        elif numDims == 4:
            numComponents = dim[3]
            dim = (dim[0], dim[1], dim[2])

        typecode = imArray.dtype.char
        try:
            ar_type = self.__typeDict[typecode]
        except KeyError:
            # Same exception type as before, but say what was wrong.
            raise KeyError("unsupported array dtype %s (typecode %r)"
                           % (imArray.dtype, typecode))

        if typecode == 'F' or typecode == 'D':
            # each complex value becomes two scalar components
            numComponents = numComponents * 2

        data = imArray
        # Test width/kind rather than the typecode so that 32-bit ints
        # spelled 'l' (C long on Windows) are converted as well.
        if (self.__ConvertIntToUnsignedShort and
                data.dtype.kind == 'i' and data.dtype.itemsize == 4):
            # 'H' produces the same bytes as the former 'h' cast, but the
            # dtype now agrees with the VTK type declared below.
            data = data.astype('H')
            ar_type = VTK_UNSIGNED_SHORT

        # The importer reads the array through the buffer interface as one
        # flat block of memory, so strided/transposed/Fortran-ordered input
        # must first be laid out in C order.
        if not data.flags.c_contiguous:
            data = data.copy(order='C')

        # nbytes == element count * item size; for complex types the item
        # size already covers both components.
        size = data.nbytes
        self.__import.CopyImportVoidPointer(data, size)
        self.__import.SetDataScalarType(ar_type)
        self.__import.SetNumberOfScalarComponents(numComponents)
        extent = self.__import.GetDataExtent()
        self.__import.SetDataExtent(extent[0],extent[0]+dim[2]-1,
                                    extent[2],extent[2]+dim[1]-1,
                                    extent[4],extent[4]+dim[0]-1)
        self.__import.SetWholeExtent(extent[0],extent[0]+dim[2]-1,
                                     extent[2],extent[2]+dim[1]-1,
                                     extent[4],extent[4]+dim[0]-1)

        # Remember the caller's array only once it has been imported.
        self.__Array = imArray

    def GetArray(self):
        """Return the array last passed to SetArray() (None if never set)."""
        return self.__Array

    # pass-through accessors for the underlying vtkImageImport

    def SetDataExtent(self,extent):
        """Forward `extent` to vtkImageImport.SetDataExtent()."""
        self.__import.SetDataExtent(extent)

    def GetDataExtent(self):
        """Return vtkImageImport.GetDataExtent()."""
        return self.__import.GetDataExtent()

    def SetDataSpacing(self,spacing):
        """Forward `spacing` to vtkImageImport.SetDataSpacing()."""
        self.__import.SetDataSpacing(spacing)

    def GetDataSpacing(self):
        """Return vtkImageImport.GetDataSpacing()."""
        return self.__import.GetDataSpacing()

    def SetDataOrigin(self,origin):
        """Forward `origin` to vtkImageImport.SetDataOrigin()."""
        self.__import.SetDataOrigin(origin)

    def GetDataOrigin(self):
        """Return vtkImageImport.GetDataOrigin()."""
        return self.__import.GetDataOrigin()
149