1"""This module provides classes that allow Numpy-type access
2to VTK datasets and arrays. This is best described with some examples.
3
4To normalize a VTK array:
5
from vtkmodules.vtkImagingCore import vtkRTAnalyticSource
7import vtkmodules.numpy_interface.dataset_adapter as dsa
8import vtkmodules.numpy_interface.algorithms as algs
9
10rt = vtkRTAnalyticSource()
11rt.Update()
12image = dsa.WrapDataObject(rt.GetOutput())
13rtdata = image.PointData['RTData']
14rtmin = algs.min(rtdata)
15rtmax = algs.max(rtdata)
16rtnorm = (rtdata - rtmin) / (rtmax - rtmin)
17image.PointData.append(rtnorm, 'RTData - normalized')
print(image.GetPointData().GetArray('RTData - normalized').GetRange())
19
20To calculate gradient:
21
grad = algs.gradient(rtnorm)
23
24To access subsets:
25
26>>> grad[0:10]
27VTKArray([[ 0.10729134,  0.03763443,  0.03136338],
28       [ 0.02754352,  0.03886006,  0.032589  ],
29       [ 0.02248248,  0.04127144,  0.03500038],
30       [ 0.02678365,  0.04357527,  0.03730421],
31       [ 0.01765099,  0.04571581,  0.03944477],
32       [ 0.02344007,  0.04763837,  0.04136734],
33       [ 0.01089381,  0.04929155,  0.04302051],
34       [ 0.01769151,  0.05062952,  0.04435848],
35       [ 0.002764  ,  0.05161414,  0.04534309],
36       [ 0.01010841,  0.05221677,  0.04594573]])
37
38>>> grad[:, 0]
39VTKArray([ 0.10729134,  0.02754352,  0.02248248, ..., -0.02748174,
40       -0.02410045,  0.05509736])
41
42All of this functionality is also supported for composite datasets
43even though their data arrays may be spread across multiple datasets.
44We have implemented a VTKCompositeDataArray class that handles many
45Numpy style operators and is supported by all algorithms in the
46algorithms module.
47
48This module also provides an API to access composite datasets.
49For example:
50
51from vtkmodules.vtkCommonDataModel import vtkMultiBlockDataSet
52mb = vtkMultiBlockDataSet()
53mb.SetBlock(0, image.VTKObject)
mb.SetBlock(1, image.VTKObject)
55cds = dsa.WrapDataObject(mb)
56for block in cds:
    print(block)
58
59Note that this module implements only the wrappers for datasets
60and arrays. The classes implement many useful operators. However,
61to make best use of these classes, take a look at the algorithms
62module.
63"""
try:
    import numpy
except ImportError:
    # Adjacent string literals are concatenated; the trailing space after
    # "make" keeps the two words separate in the final message.
    raise RuntimeError("This module depends on the numpy module. Please make "
                       "sure that it is installed properly.")
69
70import itertools
71import operator
72import sys
73from ..vtkCommonCore import buffer_shared
74from ..util import numpy_support
75from ..vtkCommonDataModel import vtkDataObject
76from ..vtkCommonCore import vtkWeakReference
77import weakref
78
# Lazy pairwise iterator: the builtin zip is already lazy on Python 3, while
# Python 2 needs itertools.izip for the same behavior.
izip = zip if sys.hexversion >= 0x03000000 else itertools.izip
83
def reshape_append_ones(a1, a2):
    """Return ``[a1, a2]``, padding the lower-rank array with trailing 1s.

    When both arguments are ``numpy.ndarray`` instances of different rank
    whose leading dimensions match, the one with fewer dimensions gets
    size-1 axes appended (via ``numpy.expand_dims``) until both ranks agree,
    so that e.g. shapes ``(n, 3)`` and ``(n,)`` become ``(n, 3)`` and
    ``(n, 1)``. In every other case (a non-ndarray argument, equal rank, a
    0-d array, or differing leading dimensions) the arguments are returned
    unchanged.
    """
    pair = [a1, a2]
    if not (isinstance(a1, numpy.ndarray) and isinstance(a2, numpy.ndarray)):
        return pair

    ndim1, ndim2 = len(a1.shape), len(a2.shape)
    if ndim1 == ndim2 or ndim1 == 0 or ndim2 == 0 or a1.shape[0] != a2.shape[0]:
        return pair

    # Position (0 or 1) of the operand that has fewer dimensions.
    short = 0 if ndim1 < ndim2 else 1
    for axis in range(min(ndim1, ndim2), max(ndim1, ndim2)):
        pair[short] = numpy.expand_dims(pair[short], axis)
    return pair
112
class ArrayAssociation:
    """Short names for the vtkDataObject attribute association constants.

    Each value is copied directly from the corresponding ``vtkDataObject``
    constant (POINT, CELL, FIELD, ROW).
    """
    POINT, CELL, FIELD, ROW = (
        vtkDataObject.POINT,
        vtkDataObject.CELL,
        vtkDataObject.FIELD,
        vtkDataObject.ROW,
    )
119
class VTKObjectWrapper(object):
    """Superclass for classes that wrap VTK objects with Python objects.

    This class holds a reference to the wrapped VTK object in the
    ``VTKObject`` attribute. It also forwards unresolved attribute lookups
    to that object by overloading ``__getattr__``."""
    def __init__(self, vtkobject):
        self.VTKObject = vtkobject

    def __getattr__(self, name):
        """Forwards unknown attribute requests to the VTK object.

        Python only calls ``__getattr__`` after normal lookup fails. If
        ``VTKObject`` itself is missing (for example on an instance created
        without running ``__init__``, as ``copy.copy`` does), evaluating
        ``self.VTKObject`` here would re-enter ``__getattr__`` without end
        and raise RecursionError. Raise AttributeError instead so that
        ``hasattr``/``getattr`` with a default behave normally.
        """
        if name == 'VTKObject':
            raise AttributeError("'%s' object has no attribute '%s'" %
                                 (self.__class__.__name__, name))
        return getattr(self.VTKObject, name)
131
def vtkDataArrayToVTKArray(array, dataset=None):
    """Wrap a vtkDataArray (owned by *dataset*) as a VTKArray.

    The numpy view is obtained with ``numpy_support.vtk_to_numpy``. A 2-d
    result with exactly 9 components per tuple is presented as a stack of
    3x3 matrices; VTK stores those matrices in Fortran order, so the last
    two axes are swapped after reshaping.
    """
    narray = numpy_support.vtk_to_numpy(array)

    is_matrix_array = len(narray.shape) == 2 and narray.shape[1] == 9
    if is_matrix_array:
        ntuples = narray.shape[0]
        narray = narray.reshape((ntuples, 3, 3)).transpose(0, 2, 1)

    return VTKArray(narray, array=array, dataset=dataset)
143
def numpyTovtkDataArray(array, name="numpy_array", array_type=None):
    """Convert a numpy array (or VTKArray) into a vtkDataArray named *name*.

    The conversion is performed by ``numpy_support.numpy_to_vtk``, which
    receives *array_type* unchanged. The returned vtkDataArray keeps a
    reference to the numpy array, so the numpy array is released only when
    the vtkDataArray is destroyed.
    """
    vtk_array = numpy_support.numpy_to_vtk(array, array_type=array_type)
    vtk_array.SetName(name)
    return vtk_array
151
152def _make_tensor_array_contiguous(array):
153    if array is None:
154        return None
155    if array.flags.contiguous:
156        return array
157    array = numpy.asarray(array)
158    size = array.dtype.itemsize
159    strides = array.strides
160    if len(strides) == 3 and strides[1]/size == 1 and strides[2]/size == 3:
161        return array.transpose(0, 2, 1)
162    return array
163
164def _metaclass(mcs):
165    """For compatibility between python 2 and python 3."""
166    def decorator(cls):
167        body = vars(cls).copy()
168        body.pop('__dict__', None)
169        body.pop('__weakref__', None)
170        return mcs(cls.__name__, cls.__bases__, body)
171    return decorator
172
class VTKArrayMetaClass(type):
    def __new__(mcs, name, parent, attr):
        """We overwrite numerical/comparison operators because we might need
        to reshape one of the arrays to perform the operation without
        broadcast errors. For instance:

        An array G of shape (n,3) resulted from computing the
        gradient on a scalar array S of shape (n,) cannot be added together without
        reshaping.
        G + expand_dims(S,1) works,
        G + S gives an error:
        ValueError: operands could not be broadcast together with shapes (n,3) (n,)

        This metaclass overwrites operators such that it computes this
        reshape operation automatically by appending 1s to the
        dimensions of the array with fewer dimensions.

        """
        def add_numeric_op(attr_name):
            """Create an attribute named attr_name that calls
            VTKArray._numeric_op(self, other, attr_name)."""
            def closure(self, other):
                return VTKArray._numeric_op(self, other, attr_name)
            closure.__name__ = attr_name
            attr[attr_name] = closure

        def add_default_numeric_op(op_name):
            """Adds the '__[op_name]__' attribute."""
            add_numeric_op("__%s__"%op_name)

        def add_reverse_numeric_op(attr_name):
            """Create an attribute named attr_name that calls
            VTKArray._reverse_numeric_op(self, other, attr_name)."""
            def closure(self, other):
                return VTKArray._reverse_numeric_op(self, other, attr_name)
            closure.__name__ = attr_name
            attr[attr_name] = closure

        def add_default_reverse_numeric_op(op_name):
            """Adds the reflected '__r[op_name]__' attribute."""
            add_reverse_numeric_op("__r%s__"%op_name)

        def add_default_numeric_ops(op_name):
            """Call both add_default_numeric_op and add_default_reverse_numeric_op."""
            add_default_numeric_op(op_name)
            add_default_reverse_numeric_op(op_name)

        add_default_numeric_ops("add")
        add_default_numeric_ops("sub")
        add_default_numeric_ops("mul")
        if sys.hexversion < 0x03000000:
            add_default_numeric_ops("div")
        add_default_numeric_ops("truediv")
        add_default_numeric_ops("floordiv")
        add_default_numeric_ops("mod")
        add_default_numeric_ops("pow")
        add_default_numeric_ops("lshift")
        add_default_numeric_ops("rshift")
        # Bitwise operators must be registered under their dunder names
        # (__and__/__rand__, __or__/__ror__) for Python to dispatch & and |
        # through the reshaping wrappers like every other operator.
        add_default_numeric_ops("and")
        add_default_numeric_ops("xor")
        add_default_numeric_ops("or")

        # Comparisons have no reflected variants to install.
        add_default_numeric_op("lt")
        add_default_numeric_op("le")
        add_default_numeric_op("eq")
        add_default_numeric_op("ne")
        add_default_numeric_op("ge")
        add_default_numeric_op("gt")
        return type.__new__(mcs, name, parent, attr)
242
@_metaclass(VTKArrayMetaClass)
class VTKArray(numpy.ndarray):
    """This is a sub-class of numpy ndarray that stores a
    reference to a vtk array as well as the owning dataset.
    The numpy array and vtk array should point to the same
    memory location.

    Attributes:
        VTKObject: the wrapped vtkDataArray, or None when this array does
            not share its buffer with the array it was derived from.
        Association: an ArrayAssociation value (FIELD when constructed
            directly).
        DataSet: the owning dataset, held through a vtkWeakReference.
    """

    def _numeric_op(self, other, attr_name):
        """Implement a binary operator such as ``__add__`` or ``__mul__``.

        The operand with fewer dimensions is padded with trailing 1s by
        reshape_append_ones before calling the numpy.ndarray method named
        *attr_name*.
        """
        l = reshape_append_ones(self, other)
        return getattr(numpy.ndarray, attr_name)(l[0], l[1])

    def _reverse_numeric_op(self, other, attr_name):
        """Implement a reflected operator such as ``__radd__``.

        *attr_name* names the reflected numpy.ndarray method (e.g.
        ``__rsub__``), which itself computes ``other <op> self``; operands
        are therefore passed in (self, other) order after padding with
        reshape_append_ones.
        """
        l = reshape_append_ones(self, other)
        return getattr(numpy.ndarray, attr_name)(l[0], l[1])

    def __new__(cls, input_array, array=None, dataset=None):
        # Input array is an already formed ndarray instance
        # We first cast to be our class type
        obj = numpy.asarray(input_array).view(cls)
        obj.Association = ArrayAssociation.FIELD
        # add the new attributes to the created instance
        obj.VTKObject = array
        if dataset:
            # Hold the dataset weakly so this array does not keep it alive.
            obj._dataset = vtkWeakReference()
            obj._dataset.Set(dataset.VTKObject)
        # Finally, we must return the newly created object:
        return obj

    def __array_finalize__(self,obj):
        # numpy calls this whenever a VTKArray is derived from another array
        # (view, slice, ufunc output). Copy the VTK array only if the two
        # share data.
        slf = _make_tensor_array_contiguous(self)
        obj2 = _make_tensor_array_contiguous(obj)

        self.VTKObject = None
        try:
            # This line tells us that they are referring to the same buffer.
            # Much like two pointers referring to same memory location in C/C++.
            if buffer_shared(slf, obj2):
                self.VTKObject = getattr(obj, 'VTKObject', None)
        except TypeError:
            pass

        self.Association = getattr(obj, 'Association', None)
        self.DataSet = getattr(obj, 'DataSet', None)

    def __getattr__(self, name):
        "Forwards unknown attribute requests to VTK array."
        try:
            o = self.__dict__["VTKObject"]
        except KeyError:
            o = None
        if o is None:
            raise AttributeError("'%s' object has no attribute '%s'" %
                                 (self.__class__.__name__, name))
        return getattr(o, name)

    def __array_wrap__(self, out_arr, context=None, *args, **kwargs):
        """Return 0-d results as scalars; otherwise defer to ndarray.

        NumPy 2 also passes a ``return_scalar`` argument; any extra
        positional/keyword arguments are forwarded unchanged so this works
        with both the NumPy 1.x and 2.x calling conventions.
        """
        if out_arr.shape == ():
            # Convert to scalar value
            return out_arr[()]
        return numpy.ndarray.__array_wrap__(self, out_arr, context,
                                            *args, **kwargs)

    @property
    def DataSet(self):
        """
        Get the dataset this array is associated with. The reference to the
        dataset is held through a vtkWeakReference to ensure it doesn't prevent
        the dataset from being collected if necessary.
        """
        if hasattr(self, '_dataset') and self._dataset and self._dataset.Get():
            return WrapDataObject(self._dataset.Get())

        return None

    @DataSet.setter
    def DataSet(self, dataset):
        """
        Set the dataset this array is associated with. The reference is held
        through a vtkWeakReference.
        """
        # Do we have dataset to store
        if dataset and dataset.VTKObject:
            # Do we need to create a vtkWeakReference
            if not hasattr(self, '_dataset') or self._dataset is None:
                self._dataset = vtkWeakReference()

            self._dataset.Set(dataset.VTKObject)
        else:
            self._dataset = None
337
class VTKNoneArrayMetaClass(type):
    def __new__(mcs, name, parent, attr):
        """Install the numeric/logical/comparison operators on the class.

        Every installed method delegates to ``VTKNoneArray._op`` with the
        matching ``operator`` function, so any operation on a void array
        produces a void array.
        """
        def install(method_name, func):
            # A separate call per method gives each closure its own func.
            def method(self, other):
                return VTKNoneArray._op(self, other, func)
            method.__name__ = method_name
            attr[method_name] = method

        arithmetic = ["add", "sub", "mul"]
        if sys.hexversion < 0x03000000:
            arithmetic.append("div")
        arithmetic.extend(["truediv", "floordiv", "mod", "pow",
                           "lshift", "rshift"])
        for op_name in arithmetic:
            func = getattr(operator, op_name)
            install("__%s__" % op_name, func)
            install("__r%s__" % op_name, func)

        bitwise = (("and", operator.and_),
                   ("xor", operator.xor),
                   ("or", operator.or_))
        for op_name, func in bitwise:
            install("__%s__" % op_name, func)
            install("__r%s__" % op_name, func)

        # Comparisons are installed without reflected variants.
        for op_name in ("lt", "le", "eq", "ne", "ge", "gt"):
            install("__%s__" % op_name, getattr(operator, op_name))

        return type.__new__(mcs, name, parent, attr)
386
@_metaclass(VTKNoneArrayMetaClass)
class VTKNoneArray(object):
    """A placeholder for an array that does not exist.

    The singleton instance ``NoneArray`` is handed out instead of None when
    a DataSetAttributes is asked for a missing array, and every operation
    on it yields ``NoneArray`` again. This lets a process that owns an
    empty dataset still evaluate a whole expression in parallel runs, where
    some functions may perform bulk MPI communication and every process must
    take part. Plain None cannot be used for this because it cannot
    override operators such as __add__ or __sub__.
    """

    def astype(self, dtype):
        """Mirror ndarray.astype: converting a void array yields NoneArray."""
        return NoneArray

    def _op(self, other, op):
        """Shared body of every operator installed by the metaclass
        (__add__, __radd__, __lt__, ...): always returns NoneArray."""
        return NoneArray

    def __getitem__(self, index):
        """Indexing or slicing a void array yields NoneArray."""
        return NoneArray

# Module-wide singleton; callers test for it with ``is NoneArray``.
NoneArray = VTKNoneArray()
414
class VTKCompositeDataArrayMetaClass(type):
    def __new__(mcs, name, parent, attr):
        """Install the numeric/logical/comparison operators on the class.

        Forward operators delegate to ``VTKCompositeDataArray._numeric_op``
        and reflected ones (``__r*__``) to ``_reverse_numeric_op``, each
        with the matching ``operator`` function.
        """
        def install_forward(method_name, func):
            def method(self, other):
                return VTKCompositeDataArray._numeric_op(self, other, func)
            method.__name__ = method_name
            attr[method_name] = method

        def install_reflected(method_name, func):
            def method(self, other):
                return VTKCompositeDataArray._reverse_numeric_op(self, other, func)
            method.__name__ = method_name
            attr[method_name] = method

        arithmetic = ["add", "sub", "mul"]
        if sys.hexversion < 0x03000000:
            arithmetic.append("div")
        arithmetic.extend(["truediv", "floordiv", "mod", "pow",
                           "lshift", "rshift"])
        for op_name in arithmetic:
            func = getattr(operator, op_name)
            install_forward("__%s__" % op_name, func)
            install_reflected("__r%s__" % op_name, func)

        bitwise = (("and", operator.and_),
                   ("xor", operator.xor),
                   ("or", operator.or_))
        for op_name, func in bitwise:
            install_forward("__%s__" % op_name, func)
            install_reflected("__r%s__" % op_name, func)

        # Comparisons are installed without reflected variants.
        for op_name in ("lt", "le", "eq", "ne", "ge", "gt"):
            install_forward("__%s__" % op_name, getattr(operator, op_name))

        return type.__new__(mcs, name, parent, attr)
471
@_metaclass(VTKCompositeDataArrayMetaClass)
class VTKCompositeDataArray(object):
    """This class manages a set of arrays of the same name contained
    within a composite dataset. Its main purpose is to provide a
    Numpy-type interface to composite data arrays which are naturally
    nothing but a collection of vtkDataArrays. A VTKCompositeDataArray
    makes such a collection appear as a single Numpy array and support
    all array operations that this module and the associated algorithm
    module support. Note that this is not a subclass of a Numpy array
    and as such cannot be passed to native Numpy functions. Instead
    VTK modules should be used to process composite arrays.
    """

    def __init__(self, arrays = None, dataset = None, name = None,
                 association = None):
        """Construct a composite array given a container of
        arrays, a dataset, name and association. It is sufficient
        to define a container of arrays to define a composite array.
        It is also possible to initialize an array by defining
        the dataset, name and array association. In that case,
        the underlying arrays will be created lazily when they
        are needed. It is recommended to use the latter method
        when initializing from an existing composite dataset.

        When *arrays* is omitted each instance gets its own empty list, so
        the list returned by GetArrays() is never shared between instances.
        If *association* is None it is inferred from the arrays; if they
        disagree, FIELD is used."""
        if arrays is None:
            arrays = []
        self._Arrays = arrays
        self.DataSet = dataset
        self.Name = name
        validAssociation = True
        if association is None:
            for array in self._Arrays:
                if hasattr(array, "Association"):
                    if association is None:
                        association = array.Association
                    elif array.Association and association != array.Association:
                        validAssociation = False
                        break
        if validAssociation:
            self.Association = association
        else:
            self.Association = ArrayAssociation.FIELD
        self.Initialized = False

    def __init_from_composite(self):
        """Populate _Arrays from DataSet/Name on first use (runs once)."""
        if self.Initialized:
            return

        self.Initialized = True

        if self.DataSet is None or self.Name is None:
            # Constructed from an explicit container of arrays.
            return

        self._Arrays = []
        for ds in self.DataSet:
            self._Arrays.append(ds.GetAttributes(self.Association)[self.Name])

    def GetSize(self):
        "Returns the number of elements in the array."
        self.__init_from_composite()
        size = numpy.int64(0)
        for a in self._Arrays:
            try:
                size += a.size
            except AttributeError:
                # NoneArray (and anything else without .size) counts as 0.
                pass
        return size

    size = property(GetSize)

    def GetArrays(self):
        """Returns the internal container of VTKArrays. If necessary,
        this will populate the array list from a composite dataset."""
        self.__init_from_composite()
        return self._Arrays

    Arrays = property(GetArrays)

    def __getitem__(self, index):
        """Overwritten to refer indexing to underlying VTKArrays.
        For the most part, this will behave like Numpy. Note
        that indexing is done per array - arrays are never treated
        as forming a bigger array. If the index is another composite
        array, a one-to-one mapping between arrays is assumed; a void
        (NoneArray) array or index yields NoneArray for that block.
        """
        self.__init_from_composite()
        res = []
        if isinstance(index, VTKCompositeDataArray):
            for a, idx in izip(self._Arrays, index.Arrays):
                if a is not NoneArray and idx is not NoneArray:
                    res.append(a.__getitem__(idx))
                else:
                    res.append(NoneArray)
        else:
            for a in self._Arrays:
                if a is not NoneArray:
                    res.append(a.__getitem__(index))
                else:
                    res.append(NoneArray)
        return VTKCompositeDataArray(res, dataset=self.DataSet)

    def _numeric_op(self, other, op):
        """Apply ``op(self_block, other)`` block by block (``__add__``,
        ``__mul__``, ...). A composite *other* is paired block by block;
        any void block yields NoneArray."""
        self.__init_from_composite()
        res = []
        if isinstance(other, VTKCompositeDataArray):
            for a1, a2 in izip(self._Arrays, other.Arrays):
                if a1 is not NoneArray and a2 is not NoneArray:
                    l = reshape_append_ones(a1, a2)
                    res.append(op(l[0],l[1]))
                else:
                    res.append(NoneArray)
        else:
            for a in self._Arrays:
                if a is not NoneArray:
                    l = reshape_append_ones(a, other)
                    res.append(op(l[0], l[1]))
                else:
                    res.append(NoneArray)
        return VTKCompositeDataArray(
            res, dataset=self.DataSet, association=self.Association)

    def _reverse_numeric_op(self, other, op):
        """Reflected counterpart of _numeric_op (``__radd__``, ...): applies
        ``op(other, self_block)`` block by block; any void block yields
        NoneArray."""
        self.__init_from_composite()
        res = []
        if isinstance(other, VTKCompositeDataArray):
            for a1, a2 in izip(self._Arrays, other.Arrays):
                if a1 is not NoneArray and a2 is not NoneArray:
                    l = reshape_append_ones(a2,a1)
                    res.append(op(l[0],l[1]))
                else:
                    res.append(NoneArray)
        else:
            for a in self._Arrays:
                if a is not NoneArray:
                    l = reshape_append_ones(other, a)
                    res.append(op(l[0], l[1]))
                else:
                    res.append(NoneArray)
        return VTKCompositeDataArray(
            res, dataset=self.DataSet, association = self.Association)

    def __str__(self):
        return self.Arrays.__str__()

    def astype(self, dtype):
        """Implements numpy array's astype method block by block; void
        (NoneArray) blocks stay void."""
        res = [NoneArray if a is NoneArray else a.astype(dtype)
               for a in self.Arrays]
        return VTKCompositeDataArray(
            res, dataset = self.DataSet, association = self.Association)
628
629
class DataSetAttributes(VTKObjectWrapper):
    """This is a python friendly wrapper of vtkDataSetAttributes. It
    returns VTKArrays. It also provides the dictionary interface."""

    def __init__(self, vtkobject, dataset, association):
        super(DataSetAttributes, self).__init__(vtkobject)
        # Plain (strong) reference to the owning dataset wrapper: it is
        # handed to the VTKArrays built by GetArray() and queried for the
        # point/cell count in append().
        self.DataSet = dataset
        self.Association = association

    def __getitem__(self, idx):
        """Implements the [] operator. Accepts an array name or index."""
        return self.GetArray(idx)

    def GetArray(self, idx):
        """Given an index or name, returns a VTKArray.

        Raises IndexError for an integer index past the last array. If no
        vtkDataArray matches *idx* but GetAbstractArray() finds one, that
        abstract array is returned unwrapped; otherwise NoneArray is
        returned.
        """
        if isinstance(idx, int) and idx >= self.VTKObject.GetNumberOfArrays():
            raise IndexError("array index out of range")
        vtkarray = self.VTKObject.GetArray(idx)
        if not vtkarray:
            vtkarray = self.VTKObject.GetAbstractArray(idx)
            if vtkarray:
                return vtkarray
            return NoneArray
        array = vtkDataArrayToVTKArray(vtkarray, self.DataSet)
        array.Association = self.Association
        return array

    def keys(self):
        """Returns the names of the arrays as a list (unnamed arrays are
        skipped)."""
        kys = []
        narrays = self.VTKObject.GetNumberOfArrays()
        for i in range(narrays):
            name = self.VTKObject.GetAbstractArray(i).GetName()
            if name:
                kys.append(name)
        return kys

    def values(self):
        """Returns the named arrays as a list of the underlying VTK abstract
        arrays (not wrapped as VTKArrays)."""
        vals = []
        narrays = self.VTKObject.GetNumberOfArrays()
        for i in range(narrays):
            a = self.VTKObject.GetAbstractArray(i)
            if a.GetName():
                vals.append(a)
        return vals

    def PassData(self, other):
        "A wrapper for vtkDataSet.PassData."
        try:
            self.VTKObject.PassData(other)
        except TypeError:
            # other is a wrapper rather than a raw VTK object.
            self.VTKObject.PassData(other.VTKObject)

    def append(self, narray, name):
        """Appends a new array to the dataset attributes.

        *narray* may be a numpy array, a VTKArray, a numpy scalar or a plain
        Python scalar; NoneArray is ignored. The target length is the number
        of points (POINT), the number of cells (CELL), or otherwise the
        array's own first dimension (1 for scalars). Scalars are repeated to
        that length; arrays whose first dimension differs are flattened and
        the flattened values are repeated for every tuple. 3-d arrays
        (arrays of matrices) are laid out in the order VTK expects and
        flattened to one vector per tuple.
        """
        if narray is NoneArray:
            # if NoneArray, nothing to do.
            return

        if self.Association == ArrayAssociation.POINT:
            arrLength = self.DataSet.GetNumberOfPoints()
        elif self.Association == ArrayAssociation.CELL:
            arrLength = self.DataSet.GetNumberOfCells()
        else:
            if not isinstance(narray, numpy.ndarray):
                arrLength = 1
            else:
                arrLength = narray.shape[0]

        # Fixup input array length:
        if not isinstance(narray, numpy.ndarray) or numpy.ndim(narray) == 0: # Scalar input
            # Plain Python scalars (int, float, ...) have no ``dtype``; let
            # numpy derive one from their type instead of failing.
            dtype = getattr(narray, 'dtype', None)
            if dtype is None:
                dtype = type(narray)
            tmparray = numpy.empty(arrLength, dtype=dtype)
            tmparray.fill(narray)
            narray = tmparray
        elif narray.shape[0] != arrLength: # Vector input
            components = 1
            for l in narray.shape:
                components *= l
            tmparray = numpy.empty((arrLength, components), dtype=narray.dtype)
            tmparray[:] = narray.flatten()
            narray = tmparray

        shape = narray.shape

        if len(shape) == 3:
            # Array of matrices. We need to make sure the order  in memory is right.
            # If column order (c order), transpose. VTK wants row order (fortran
            # order). The deep copy later will make sure that the array is contiguous.
            # If row order but not contiguous, transpose so that the deep copy below
            # does not happen.
            size = narray.dtype.itemsize
            if (narray.strides[1]/size == 3 and narray.strides[2]/size == 1) or \
                (narray.strides[1]/size == 1 and narray.strides[2]/size == 3 and \
                 not narray.flags.contiguous):
                narray  = narray.transpose(0, 2, 1)

        # If array is not contiguous, make a deep copy that is contiguous
        if not narray.flags.contiguous:
            narray = numpy.ascontiguousarray(narray)

        # Flatten array of matrices to array of vectors
        if len(shape) == 3:
            narray = narray.reshape(shape[0], shape[1]*shape[2])

        # this handle the case when an input array is directly appended on the
        # output. We want to make sure that the array added to the output is not
        # referring to the input dataset.
        copy = VTKArray(narray)
        try:
            copy.VTKObject = narray.VTKObject
        except AttributeError: pass
        arr = numpyTovtkDataArray(copy, name)
        self.VTKObject.AddArray(arr)
746
747
748class CompositeDataSetAttributes():
749    """This is a python friendly wrapper for vtkDataSetAttributes for composite
    datasets. Composite datasets themselves don't have attribute data (the
    attribute data is associated with the leaf nodes in the composite
    dataset), so this class simulates a DataSetAttributes interface by taking
    a union of the DataSetAttributes associated with all leaf nodes."""
754
755    def __init__(self, dataset, association):
756        # import weakref
757        # self.DataSet = weakref.ref(dataset)
758        self.DataSet = dataset
759        self.Association = association
760        self.ArrayNames = []
761        self.Arrays = {}
762
763        # build the set of arrays available in the composite dataset. Since
764        # composite datasets can have partial arrays, we need to iterate over
765        # all non-null blocks in the dataset.
766        self.__determine_arraynames()
767
768    def __determine_arraynames(self):
769        array_set = set()
770        array_list = []
771        for dataset in self.DataSet:
772            dsa = dataset.GetAttributes(self.Association)
773            for array_name in dsa.keys():
774                if array_name not in array_set:
775                    array_set.add(array_name)
776                    array_list.append(array_name)
777        self.ArrayNames = array_list
778
779    def keys(self):
780        """Returns the names of the arrays as a list."""
781        return self.ArrayNames
782
783    def __getitem__(self, idx):
784        """Implements the [] operator. Accepts an array name."""
785        return self.GetArray(idx)
786
787    def append(self, narray, name):
788        """Appends a new array to the composite dataset attributes."""
789        if narray is NoneArray:
790            # if NoneArray, nothing to do.
791            return
792
793        added = False
794        if not isinstance(narray, VTKCompositeDataArray): # Scalar input
795            for ds in self.DataSet:
796                ds.GetAttributes(self.Association).append(narray, name)
797                added = True
798            if added:
799                self.ArrayNames.append(name)
800                # don't add the narray since it's a scalar. GetArray() will create a
801                # VTKCompositeArray on-demand.
802        else:
803            for ds, array in izip(self.DataSet, narray.Arrays):
804                if array is not None:
805                    ds.GetAttributes(self.Association).append(array, name)
806                    added = True
807            if added:
808                self.ArrayNames.append(name)
809                self.Arrays[name] = weakref.ref(narray)
810
811    def GetArray(self, idx):
812        """Given a name, returns a VTKCompositeArray."""
813        arrayname = idx
814        if arrayname not in self.ArrayNames:
815            return NoneArray
816        if arrayname not in self.Arrays or self.Arrays[arrayname]() is None:
817            array = VTKCompositeDataArray(
818                dataset = self.DataSet, name = arrayname, association = self.Association)
819            self.Arrays[arrayname] = weakref.ref(array)
820        else:
821            array = self.Arrays[arrayname]()
822        return array
823
824    def PassData(self, other):
825        """Emulate PassData for composite datasets."""
826        for this,that in zip(self.DataSet, other.DataSet):
827            for assoc in [ArrayAssociation.POINT, ArrayAssociation.CELL]:
828                this.GetAttributes(assoc).PassData(that.GetAttributes(assoc))
829
class CompositeDataIterator(object):
    """Wrapper for a vtkCompositeDataIterator class to satisfy
       the python iterator protocol. This iterator iterates
       over non-empty leaf nodes. To iterate over empty or
       non-leaf nodes, use the vtkCompositeDataIterator directly.
       """

    def __init__(self, cds):
        self.Iterator = cds.NewIterator()
        if self.Iterator:
            # NOTE(review): presumably balances the extra reference handed out
            # by NewIterator(), so the Python wrapper alone owns the iterator.
            self.Iterator.UnRegister(None)
            self.Iterator.GoToFirstItem()

    def __iter__(self):
        return self

    def __next__(self):
        it = self.Iterator
        if not it or it.IsDoneWithTraversal():
            raise StopIteration
        current = it.GetCurrentDataObject()
        it.GoToNextItem()
        return WrapDataObject(current)

    def next(self):
        # Python 2 spelling of the iterator protocol.
        return self.__next__()

    def __getattr__(self, name):
        """Returns attributes from the vtkCompositeDataIterator."""
        # __getattr__ only runs for missing attributes. If 'Iterator' itself is
        # missing (e.g. an instance made via __new__ by copy/pickle), looking
        # it up below would re-enter __getattr__ forever.
        if name == 'Iterator':
            raise AttributeError(name)
        return getattr(self.Iterator, name)
862
class MultiCompositeDataIterator(CompositeDataIterator):
    """Iterator that can be used to iterate over multiple
    composite datasets together. This iterator works only
    with datasets that were copied from an original using
    CopyStructure. The most common use case is to use
    CopyStructure, then iterate over input and output together
    while creating output datasets from corresponding input
    datasets.

    Traversal is driven by the first composite dataset in ``cds``. Each step
    yields a list holding one wrapped leaf per dataset; the leaves of the
    other datasets are looked up with GetDataSet() at the same position."""
    def __init__(self, cds):
        CompositeDataIterator.__init__(self, cds[0])
        self.Datasets = cds

    def __next__(self):
        it = self.Iterator
        if not it or it.IsDoneWithTraversal():
            raise StopIteration
        retVal = [WrapDataObject(it.GetCurrentDataObject())]
        # Matching leaf of every other dataset at the current position.
        for cd in self.Datasets[1:]:
            retVal.append(WrapDataObject(cd.GetDataSet(it)))
        it.GoToNextItem()
        return retVal

    def next(self):
        # Python 2 spelling of the iterator protocol.
        return self.__next__()
891
class DataObject(VTKObjectWrapper):
    """A wrapper for vtkDataObject that makes it easier to access FieldData
    arrays as VTKArrays
    """

    def GetAttributes(self, type):
        """Wrap the attributes of association ``type`` in a DataSetAttributes.

        Field data is obtained through the VTK object's GetFieldData(); every
        other association goes through its GetAttributes(type)."""
        vtkobj = self.VTKObject
        if type == ArrayAssociation.FIELD:
            raw = vtkobj.GetFieldData()
        else:
            raw = vtkobj.GetAttributes(type)
        return DataSetAttributes(raw, self, type)

    def GetFieldData(self):
        "Wrap the field data of the VTK object in a DataSetAttributes."
        fielddata = self.VTKObject.GetFieldData()
        return DataSetAttributes(fielddata, self, ArrayAssociation.FIELD)

    FieldData = property(GetFieldData, None, None, "This property returns the field data of a data object.")
909
class Table(DataObject):
    """A wrapper for vtkTable that makes it easier to access its RowData
    arrays as VTKArrays
    """
    def GetRowData(self):
        "Wrap the table's row attributes in a DataSetAttributes."
        rows = ArrayAssociation.ROW
        return self.GetAttributes(rows)

    RowData = property(GetRowData, None, None, "This property returns the row data of the table.")
919
class CompositeDataSet(DataObject):
    """A wrapper for vtkCompositeData and subclasses that makes it easier
    to access Point/Cell/Field data as VTKCompositeDataArrays. It also
    provides a Python type iterator."""

    def __init__(self, vtkobject):
        DataObject.__init__(self, vtkobject)
        # Weak references to lazily built wrappers. They are rebuilt whenever
        # the previous wrapper has been garbage collected, so this object does
        # not keep them alive by itself.
        self._PointData = None
        self._CellData = None
        self._FieldData = None
        self._Points = None

    def __iter__(self):
        "Creates an iterator for the contained datasets."
        return CompositeDataIterator(self)

    def GetNumberOfElements(self, assoc):
        """Returns the total number of cells or points depending
        on the value of assoc which can be ArrayAssociation.POINT or
        ArrayAssociation.CELL."""
        return int(sum(dataset.GetNumberOfElements(assoc) for dataset in self))

    def GetNumberOfPoints(self):
        """Returns the total number of points of all datasets
        in the composite dataset. Note that this traverses the
        whole composite dataset every time and should not be
        called repeatedly for large composite datasets."""
        return self.GetNumberOfElements(ArrayAssociation.POINT)

    def GetNumberOfCells(self):
        """Returns the total number of cells of all datasets
        in the composite dataset. Note that this traverses the
        whole composite dataset every time and should not be
        called repeatedly for large composite datasets."""
        return self.GetNumberOfElements(ArrayAssociation.CELL)

    def GetAttributes(self, type):
        """Returns the attributes specified by the type as a
        CompositeDataSetAttributes instance."""
        return CompositeDataSetAttributes(self, type)

    def _GetCachedAttributes(self, slot, assoc):
        """Return the attributes for ``assoc`` weakly cached in the instance
        attribute named ``slot``, building and caching new ones when there
        is no cache yet or the cached object has been collected."""
        ref = getattr(self, slot)
        # Dereference once and keep a strong reference, so the object cannot
        # vanish between the check and the return.
        attrs = ref() if ref is not None else None
        if attrs is None:
            attrs = self.GetAttributes(assoc)
            setattr(self, slot, weakref.ref(attrs))
        return attrs

    def GetPointData(self):
        "Returns the point data as a CompositeDataSetAttributes instance."
        return self._GetCachedAttributes('_PointData', ArrayAssociation.POINT)

    def GetCellData(self):
        "Returns the cell data as a CompositeDataSetAttributes instance."
        return self._GetCachedAttributes('_CellData', ArrayAssociation.CELL)

    def GetFieldData(self):
        "Returns the field data as a CompositeDataSetAttributes instance."
        return self._GetCachedAttributes('_FieldData', ArrayAssociation.FIELD)

    def GetPoints(self):
        """Returns the points as a VTKCompositeDataArray instance, or
        NoneArray when no leaf has points."""
        cpts = self._Points() if self._Points is not None else None
        if cpts is None:
            pts = []
            for ds in self:
                # Leaves that are not point sets have no Points attribute.
                try:
                    _pts = ds.Points
                except AttributeError:
                    _pts = None
                pts.append(NoneArray if _pts is None else _pts)
            if all(a is NoneArray for a in pts):
                cpts = NoneArray
            else:
                cpts = VTKCompositeDataArray(pts, dataset=self)
            self._Points = weakref.ref(cpts)
        return cpts

    PointData = property(GetPointData, None, None, "This property returns the point data of the dataset.")
    CellData = property(GetCellData, None, None, "This property returns the cell data of a dataset.")
    FieldData = property(GetFieldData, None, None, "This property returns the field data of a dataset.")
    Points = property(GetPoints, None, None, "This property returns the points of the dataset.")
1010
class DataSet(DataObject):
    """Python friendly wrapper of a vtkDataSet exposing its point and cell
    attributes through the PointData and CellData properties."""

    def GetPointData(self):
        "Wrap the point attributes in a DataSetAttributes."
        assoc = ArrayAssociation.POINT
        return self.GetAttributes(assoc)

    def GetCellData(self):
        "Wrap the cell attributes in a DataSetAttributes."
        assoc = ArrayAssociation.CELL
        return self.GetAttributes(assoc)

    PointData = property(GetPointData, None, None, "This property returns the point data of the dataset.")
    CellData = property(GetCellData, None, None, "This property returns the cell data of a dataset.")
1025
class PointSet(DataSet):
    """Python friendly wrapper of a vtkPointSet that exposes its point
    coordinates through the Points property."""
    def GetPoints(self):
        """Return the point coordinates as a VTKArray tagged with the POINT
        association, or None when the VTK object has no points."""
        vtkpts = self.VTKObject.GetPoints()
        if not vtkpts:
            return None
        coords = vtkDataArrayToVTKArray(vtkpts.GetData(), self)
        coords.Association = ArrayAssociation.POINT
        return coords

    def SetPoints(self, pts):
        """Set the points of the dataset.

        A vtkPoints instance is used as-is; anything else (e.g. a VTKArray or
        numpy array) is converted with numpyTovtkDataArray and wrapped in a new
        vtkPoints."""
        from ..vtkCommonCore import vtkPoints
        if not isinstance(pts, vtkPoints):
            data = numpyTovtkDataArray(pts)
            pts = vtkPoints()
            pts.SetData(data)
        self.VTKObject.SetPoints(pts)

    Points = property(GetPoints, SetPoints, None, "This property returns the point coordinates of dataset.")
1051
class PolyData(PointSet):
    """Python friendly wrapper of a vtkPolyData that additionally exposes the
    polygon connectivity through the Polygons property."""

    def GetPolygons(self):
        """Return the data array of the polys cell array as a VTKArray, or
        None when the VTK object has no polys."""
        polys = self.VTKObject.GetPolys()
        if not polys:
            return None
        return vtkDataArrayToVTKArray(polys.GetData(), self)

    Polygons = property(GetPolygons, None, None, "This property returns the connectivity of polygons.")
1064
class UnstructuredGrid(PointSet):
    """This is a python friendly wrapper of a vtkUnstructuredGrid that defines
    a few useful properties."""

    def GetCellTypes(self):
        """Returns the cell types as a VTKArray instance, or None if the grid
        has no cell types array."""
        types = self.VTKObject.GetCellTypesArray()
        if not types:
            return None
        return vtkDataArrayToVTKArray(types, self)

    def GetCellLocations(self):
        """Returns the cell locations as a VTKArray instance, or None if the
        grid has no cell locations array."""
        locations = self.VTKObject.GetCellLocationsArray()
        if not locations:
            return None
        return vtkDataArrayToVTKArray(locations, self)

    def GetCells(self):
        """Returns the data array of the cell array as a VTKArray instance,
        or None if the grid has no cells."""
        cellarray = self.VTKObject.GetCells()
        if not cellarray:
            return None
        return vtkDataArrayToVTKArray(cellarray.GetData(), self)

    def SetCells(self, cellTypes, cellLocations, cells):
        """Given cellTypes, cellLocations, cells as VTKArrays (or numpy
        arrays), populates the unstructured grid data structures.

        cellTypes is converted to an unsigned char array; cellLocations and
        cells are converted to vtkIdType arrays. cells is handed to
        vtkCellArray.SetCells together with the number of cell types."""
        from ..util.vtkConstants import VTK_ID_TYPE, VTK_UNSIGNED_CHAR
        from ..vtkCommonDataModel import vtkCellArray
        # vtkUnstructuredGrid.SetCells expects a vtkUnsignedCharArray of cell
        # types. Force that type so e.g. int64 numpy input is converted rather
        # than producing an array of another type that would be rejected.
        cellTypes = numpyTovtkDataArray(cellTypes, array_type=VTK_UNSIGNED_CHAR)
        cellLocations = numpyTovtkDataArray(cellLocations, array_type=VTK_ID_TYPE)
        cells = numpyTovtkDataArray(cells, array_type=VTK_ID_TYPE)
        ca = vtkCellArray()
        ca.SetCells(cellTypes.GetNumberOfTuples(), cells)
        self.VTKObject.SetCells(cellTypes, cellLocations, ca)

    CellTypes = property(GetCellTypes, None, None, "This property returns the types of cells.")
    CellLocations = property(GetCellLocations, None, None, "This property returns the locations of cells.")
    Cells = property(GetCells, None, None, "This property returns the connectivity of cells.")
1105
class Graph(DataObject):
    """Python friendly wrapper of a vtkGraph exposing its vertex and edge
    attributes through the VertexData and EdgeData properties."""

    def GetVertexData(self):
        "Wrap the vertex attributes in a DataSetAttributes."
        assoc = ArrayAssociation.VERTEX
        return self.GetAttributes(assoc)

    def GetEdgeData(self):
        "Wrap the edge attributes in a DataSetAttributes."
        assoc = ArrayAssociation.EDGE
        return self.GetAttributes(assoc)

    VertexData = property(GetVertexData, None, None, "This property returns the vertex data of the graph.")
    EdgeData = property(GetEdgeData, None, None, "This property returns the edge data of the graph.")
1120
class Molecule(DataObject):
    """This is a python friendly wrapper of a vtkMolecule that defines
    a few useful properties."""
    def GetAtomData(self):
        "Returns the atom data as a DataSetAttributes instance."
        # Molecule derives from DataObject (not Graph), so go through
        # GetAttributes directly; atoms are stored as vertex attributes.
        return self.GetAttributes(ArrayAssociation.VERTEX)

    def GetBondData(self):
        "Returns the bond data as a DataSetAttributes instance."
        # Bonds are stored as edge attributes.
        return self.GetAttributes(ArrayAssociation.EDGE)

    AtomData = property(GetAtomData, None, None, "This property returns the atom data of the molecule.")
    BondData = property(GetBondData, None, None, "This property returns the bond data of the molecule.")
1134
def WrapDataObject(ds):
    """Returns a Numpy friendly wrapper of a vtkDataObject.

    The VTK class checks run from most to least specific (mirroring the
    wrapper class hierarchy in this module), and the first match decides the
    wrapper. Returns None if ``ds`` is not a vtkDataObject."""
    wrappers = (
        ("vtkPolyData", PolyData),
        ("vtkUnstructuredGrid", UnstructuredGrid),
        ("vtkPointSet", PointSet),
        ("vtkDataSet", DataSet),
        ("vtkCompositeDataSet", CompositeDataSet),
        ("vtkTable", Table),
        # vtkMolecule must be tested before the more general vtkGraph.
        ("vtkMolecule", Molecule),
        ("vtkGraph", Graph),
        ("vtkDataObject", DataObject),
    )
    for vtk_class, wrapper in wrappers:
        if ds.IsA(vtk_class):
            return wrapper(ds)
    return None
1155