1'''
2Generic functions for handling spectral images.
3'''
4
5from __future__ import absolute_import, division, print_function, unicode_literals
6
7import numbers
8import numpy as np
9
10from .spectral import BandInfo
11
class Image(object):
    '''spectral.Image is the common base class for spectral image objects.

    It stores the image dimensions (`nbands`, `nrows`, `ncols`), the data
    type (`dtype`), band information (`bands`) and a `metadata` dict.
    '''

    def __init__(self, params, metadata=None):
        self.bands = BandInfo()
        self.set_params(params, metadata)

    def set_params(self, params, metadata=None):
        '''Set the image dimensions, data type and metadata.

        Arguments:

            `params`:

                Object providing `nbands`, `nrows`, `ncols` and `dtype`
                attributes (e.g. the object returned by :meth:`params`).

            `metadata` (dict, optional):

                Image metadata. If falsy (None or empty), a new empty dict
                is stored instead.

        Raises AttributeError if `params` lacks one of the required
        attributes.
        '''
        self.nbands = params.nbands
        self.nrows = params.nrows
        self.ncols = params.ncols
        self.dtype = params.dtype
        self.metadata = metadata if metadata else {}

    def params(self):
        '''Return an object containing the SpyFile parameters.

        The returned object has `nbands`, `nrows`, `ncols`, `metadata` and
        `dtype` attributes; callers may attach further attributes to it.
        '''

        class P:
            pass
        p = P()

        p.nbands = self.nbands
        p.nrows = self.nrows
        p.ncols = self.ncols
        p.metadata = self.metadata
        p.dtype = self.dtype

        return p

    def __repr__(self):
        # The default object.__str__ falls back to __repr__, so blindly
        # delegating to __str__ would recurse forever for any class that
        # does not define its own __str__.
        if type(self).__str__ is object.__str__:
            return object.__repr__(self)
        return self.__str__()
50
51
class ImageArray(np.ndarray, Image):
    '''ImageArray is an interface to an image loaded entirely into memory.
    ImageArray objects are returned by :meth:`spectral.SpyFile.load`.
    This class inherits from both numpy.ndarray and Image, providing the
    interfaces of both classes.
    '''

    format = 'f'        # Use 4-byte floats for data arrays

    def __new__(cls, data, spyfile):
        # View the data as an ImageArray (no copy when `data` is already an
        # ndarray) and copy the image parameters from `spyfile`.
        obj = np.asarray(data).view(cls)
        # NOTE: when constructed as ImageArray(data, spyfile), Python calls
        # __init__ again after __new__ returns; both calls set the same
        # attributes.
        ImageArray.__init__(obj, data, spyfile)
        return obj

    def __init__(self, data, spyfile):
        '''Copy image parameters, metadata and band info from `spyfile`.

        `spyfile` must provide `params()`, `metadata`, `bands` and
        `filename`.
        '''
        # Add param data to Image initializer
        params = spyfile.params()
        # Use the dtype of the array wrapped by `self`; unlike `data.dtype`
        # this also works when `data` is a non-ndarray array-like.
        params.dtype = self.dtype
        params.swap = 0

        Image.__init__(self, params, spyfile.metadata)
        self.bands = spyfile.bands
        self.filename = spyfile.filename
        self.interleave = 2  # bip

    def __repr__(self):
        lst = np.array2string(np.asarray(self), prefix="ImageArray(")
        return "{}({}, dtype={})".format('ImageArray', lst, self.dtype.name)

    @staticmethod
    def _scalar_to_slice(index):
        '''Return a length-1 slice selecting the element at `index`.'''
        # slice(-1, 0) would be empty, so the last element needs an
        # open-ended slice.
        if index == -1:
            return slice(index, None)
        return slice(index, index + 1)

    def __getitem__(self, args):
        # Duplicate the indexing behavior of SpyFile.  If args is iterable
        # with length greater than one, and if not all of the args are
        # scalars, then the scalars need to be replaced with slices.
        try:
            iterator = iter(args)
        except TypeError:
            # A single, non-iterable index (e.g. img[3]).
            if isinstance(args, numbers.Number):
                args = self._scalar_to_slice(args)
            return self._parent_getitem(args)

        keep_original_args = True
        updated_args = []
        for arg in iterator:
            if isinstance(arg, numbers.Number):
                updated_args.append(self._scalar_to_slice(arg))
            else:
                updated_args.append(arg)
                # Boolean scalars do not force scalar-to-slice conversion.
                if not isinstance(arg, np.bool_):
                    keep_original_args = False

        if keep_original_args:
            # All indices are scalars: use plain ndarray indexing.
            return self._parent_getitem(args)
        return self._parent_getitem(tuple(updated_args))

    def _parent_getitem(self, args):
        return np.ndarray.__getitem__(self, args)

    @staticmethod
    def _all_bands_requested(bands):
        '''Return True if `bands` means "all bands" (None or empty).'''
        if bands is None:
            return True
        try:
            return len(bands) == 0
        except TypeError:
            # A scalar band index selects a single band.
            return False

    def read_band(self, i):
        '''
        For compatibility with SpyFile objects. Returns arr[:,:,i].squeeze()
        '''
        return np.asarray(self[:, :, i].squeeze())

    def read_bands(self, bands):
        '''For SpyFile compatibility. Equivalent to arr.take(bands, 2)'''
        return np.asarray(self.take(bands, 2))

    def read_pixel(self, row, col):
        '''For SpyFile compatibility. Equivalent to arr[row, col]'''
        return np.asarray(self[row, col])

    def read_subregion(self, row_bounds, col_bounds, bands=None):
        '''
        For SpyFile compatibility.

        Equivalent to arr[slice(*row_bounds), slice(*col_bounds), bands],
        selecting all bands if `bands` is None or empty.
        '''
        rows = slice(*row_bounds)
        cols = slice(*col_bounds)
        if self._all_bands_requested(bands):
            return np.asarray(self[rows, cols])
        return np.asarray(self[rows, cols, bands])

    def read_subimage(self, rows, cols, bands=None):
        '''
        For SpyFile compatibility.

        Equivalent to arr[rows][:, cols][:, :, bands], selecting all bands
        if `bands` is None or empty.
        '''
        subimage = self[rows][:, cols]
        if self._all_bands_requested(bands):
            return np.asarray(subimage)
        return np.asarray(subimage[:, :, bands])

    def read_datum(self, i, j, k):
        '''For SpyFile compatibility. Equivalent to arr[i, j, k], returned
        as a Python scalar.'''
        # np.asscalar (removed in NumPy 1.23) was simply `a.item()`.
        return self[i, j, k].item()

    def load(self):
        '''For compatibility with SpyFile objects. Returns self'''
        return self

    def asarray(self, writable=False):
        '''Returns an object with a standard numpy array interface.

        The return value is the same as calling `numpy.asarray`, except
        that the array is not writable by default to match the behavior
        of `SpyFile.asarray`.

        This function is for compatibility with SpyFile objects.

        Keyword Arguments:

            `writable` (bool, default False):

                If `writable` is True, modifying values in the returned
                array will result in corresponding modification to the
                ImageArray object.
        '''
        arr = np.asarray(self)
        if not writable:
            # Only the returned view is made read-only; `self` is unchanged.
            arr.setflags(write=False)
        return arr

    def info(self):
        '''Return a text summary of the image dimensions and data type.'''
        s = '\t# Rows:         %6d\n' % (self.nrows)
        s += '\t# Samples:      %6d\n' % (self.ncols)
        s += '\t# Bands:        %6d\n' % (self.shape[2])

        s += '\tData format:  %8s' % self.dtype.name
        return s

    def __array_wrap__(self, out_arr, context=None, return_scalar=False):
        # The ndarray __array_wrap__ causes ufunc results to be of type
        # ImageArray.  Instead, return a plain ndarray.  NumPy >= 2 also
        # passes `return_scalar`, requesting that a 0-d result be returned
        # as a scalar.
        if return_scalar and np.ndim(out_arr) == 0:
            return out_arr[()]
        return out_arr

    # Some methods do not call __array_wrap__ and will return an ImageArray.
    # Currently, these need to be overridden individually or with
    # __getattribute__ magic.

    def __getattribute__(self, name):
        # Attributes defined by ndarray (and not overridden here) are looked
        # up on a plain ndarray view, so their results are plain ndarrays.
        if ((name in np.ndarray.__dict__) and
            (name not in ImageArray.__dict__)):
            return getattr(np.asarray(self), name)

        return super(ImageArray, self).__getattribute__(name)
215
216