1"""
2A buffered iterator for big arrays.
3
4This module solves the problem of iterating over a big file-based array
5without having to read it into memory. The `Arrayterator` class wraps
6an array object, and when iterated it will return sub-arrays with at most
7a user-specified number of elements.
8
9"""
10from operator import mul
11from functools import reduce
12
13__all__ = ['Arrayterator']
14
15
class Arrayterator:
    """
    Buffered iterator for big arrays.

    `Arrayterator` creates a buffered iterator for reading big arrays in small
    contiguous blocks. The class is useful for objects stored in the
    file system. It allows iteration over the object *without* reading
    everything in memory; instead, small blocks are read and iterated over.

    `Arrayterator` can be used with any object that supports multidimensional
    slices. This includes NumPy arrays, but also variables from
    Scientific.IO.NetCDF or pynetcdf for example.

    Parameters
    ----------
    var : array_like
        The object to iterate over. It must expose ``shape`` and ``ndim``
        and accept a tuple of slices as index.
    buf_size : int, optional
        The buffer size. If `buf_size` is supplied, the maximum amount of
        data that will be read into memory is `buf_size` elements.
        Default is None, which will read as many elements as possible
        into memory. A value of 0 is treated like None.

    Attributes
    ----------
    var
    buf_size
    start
    stop
    step
    shape
    flat

    See Also
    --------
    ndenumerate : Multidimensional array iterator.
    flatiter : Flat array iterator.
    memmap : Create a memory-map to an array stored in a binary file on disk.

    Notes
    -----
    The algorithm works by first finding a "running dimension", along which
    the blocks will be extracted. Given an array of dimensions
    ``(d1, d2, ..., dn)``, e.g. if `buf_size` is smaller than ``d1``, the
    first dimension will be used. If, on the other hand,
    ``d1 < buf_size < d1*d2`` the second dimension will be used, and so on.
    Blocks are extracted along this dimension, and when the last block is
    returned the process continues from the next dimension, until all
    elements have been read.

    Attributes that are not defined on the `Arrayterator` itself are looked
    up on `var`. Such attributes (``dtype``, ``size``, ...) therefore
    describe the whole wrapped object, not the current selection.

    Examples
    --------
    >>> a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
    >>> a_itor = np.lib.Arrayterator(a, 2)
    >>> a_itor.shape
    (3, 4, 5, 6)

    Now we can iterate over ``a_itor``, and it will return arrays of size
    two. Since `buf_size` was smaller than any dimension, the first
    dimension will be iterated over first:

    >>> for subarr in a_itor:
    ...     if not subarr.all():
    ...         print(subarr, subarr.shape) # doctest: +SKIP
    >>> # [[[[0 1]]]] (1, 1, 1, 2)

    """

    # Names that `__getattr__` must never forward to `var`:
    # * 'var' itself -- forwarding it would recurse without end when the
    #   attribute is missing (instances built by copy/pickle/__new__);
    # * the array-export attributes -- they describe the memory of the whole
    #   wrapped array, so a consumer that finds them on a sliced
    #   Arrayterator would see every element instead of the selection.
    _UNFORWARDED = frozenset(
        ('var', '__array_struct__', '__array_interface__'))

    def __init__(self, var, buf_size=None):
        self.var = var
        self.buf_size = buf_size

        # The selection initially spans the whole object with unit steps.
        self.stop = list(var.shape)
        self.start = [0 for _ in self.stop]
        self.step = [1 for _ in self.stop]

    def __getattr__(self, attr):
        """
        Forward lookups of unknown attributes to the wrapped object.

        Raises
        ------
        AttributeError
            If `attr` is one of the names that are deliberately not
            forwarded, or if the wrapped object does not have it.

        """
        # Only reached when normal attribute lookup has already failed.
        if attr in self._UNFORWARDED:
            raise AttributeError(
                "%r object has no attribute %r"
                % (type(self).__name__, attr))
        return getattr(self.var, attr)

    def __getitem__(self, index):
        """
        Return a new arrayterator restricted to `index`.

        Parameters
        ----------
        index : int, slice, Ellipsis or tuple of those
            Positions are relative to the current selection, i.e. they
            count elements of ``self.shape``. Negative integers and
            negative slice bounds count from the end. An integer keeps its
            dimension (with length 1); an integer outside the dimension
            yields an empty selection. Missing trailing dimensions are
            taken whole.

        Returns
        -------
        Arrayterator
            A new object sharing `var` and `buf_size` with this one.

        Raises
        ------
        ValueError
            If a slice has a zero or negative step.

        """
        # Fix index, handling ellipsis and incomplete slices.
        if not isinstance(index, tuple):
            index = (index,)
        fixed = []
        length, dims = len(index), self.ndim
        for item in index:
            if item is Ellipsis:
                fixed.extend([slice(None)] * (dims-length+1))
                length = len(fixed)
            else:
                fixed.append(item)
        if len(fixed) < dims:
            fixed.extend([slice(None)] * (dims-len(fixed)))

        extents = self.shape

        # Return a new arrayterator object.
        out = self.__class__(self.var, self.buf_size)
        for i, (start, stop, step, item) in enumerate(
                zip(self.start, self.stop, self.step, fixed)):
            size = max(extents[i], 0)

            # Integers (including NumPy integer scalars) select a
            # length-one range along the dimension.
            if not isinstance(item, slice) and hasattr(item, '__index__'):
                pos = item.__index__()
                if pos < 0:
                    pos += size
                if 0 <= pos < size:
                    item = slice(pos, pos+1, 1)
                else:
                    item = slice(size, size, 1)  # out of range: empty

            # `indices` resolves None and negative bounds and clips to the
            # dimension; unlike ``stop or default`` it keeps an explicit 0.
            first, last, stride = item.indices(size)
            if stride < 0:
                raise ValueError("negative steps are not supported")

            # Convert view positions to positions in `var`: consecutive
            # elements of the current selection are `step` apart.
            out.start[i] = start + first*step
            out.step[i] = step * stride
            # Never pass the current stop; never end before the start, so
            # that an empty selection has length 0 rather than a negative
            # length.
            out.stop[i] = max(out.start[i], min(stop, start + last*step))
        return out

    def __array__(self, dtype=None, copy=None):
        """
        Return corresponding data.

        Parameters
        ----------
        dtype, copy : optional
            Accepted so that the method can be called with the arguments of
            the array protocol. They are not used here; the selected data
            is returned exactly as `var` provides it.

        """
        slice_ = tuple(slice(*t) for t in zip(
                self.start, self.stop, self.step))
        return self.var[slice_]

    @property
    def flat(self):
        """
        A 1-D flat iterator for Arrayterator objects.

        This iterator returns elements of the array to be iterated over in
        `Arrayterator` one by one. It is similar to `flatiter`.

        See Also
        --------
        Arrayterator
        flatiter

        Examples
        --------
        >>> a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
        >>> a_itor = np.lib.Arrayterator(a, 2)

        >>> for subarr in a_itor.flat:
        ...     if not subarr:
        ...         print(subarr, type(subarr))
        ...
        0 <class 'numpy.int64'>

        """
        for block in self:
            yield from block.flat

    @property
    def shape(self):
        """
        The shape of the array to be iterated over.

        For an example, see `Arrayterator`.

        """
        return tuple(((stop-start-1)//step+1) for start, stop, step in
                zip(self.start, self.stop, self.step))

    def __iter__(self):
        """
        Yield the selection block by block.

        Each block is ``var`` indexed with a tuple of slices and holds at
        most `buf_size` elements (everything at once when `buf_size` is
        None or 0). Nothing is yielded when a dimension is empty.

        Raises
        ------
        ValueError
            If `buf_size` is negative. As this is a generator, the error is
            raised when the first block is requested.

        """
        # The property rebuilds the tuple on every access; read it once.
        shape = self.shape

        # Skip arrays with degenerate dimensions
        if any(dim <= 0 for dim in shape):
            return

        ndims = self.var.ndim
        if ndims == 0:
            # A 0-d object has no running dimension: one single block.
            yield self.var[()]
            return

        block_size = self.buf_size or reduce(mul, shape)
        if block_size < 0:
            # A negative size would move the read position backwards and
            # the loop below would never terminate.
            raise ValueError(
                "buf_size must be positive, got %r" % (self.buf_size,))

        start = self.start[:]
        stop = self.stop[:]
        step = self.step[:]

        while True:
            count = block_size

            # iterate over each dimension, looking for the
            # running dimension (ie, the dimension along which
            # the blocks will be built from)
            rundim = 0
            for i in range(ndims-1, -1, -1):
                # if count is zero we ran out of elements to read
                # along higher dimensions, so we read only a single position
                if count == 0:
                    stop[i] = start[i]+1
                elif count <= shape[i]:
                    # limit along this dimension
                    stop[i] = start[i] + count*step[i]
                    rundim = i
                else:
                    # read everything along this dimension
                    stop[i] = self.stop[i]
                stop[i] = min(self.stop[i], stop[i])
                count = count//shape[i]

            # yield a block
            slice_ = tuple(slice(*t) for t in zip(start, stop, step))
            yield self.var[slice_]

            # Update start position, taking care of overflow to
            # other dimensions
            start[rundim] = stop[rundim]  # start where we stopped
            for i in range(ndims-1, 0, -1):
                if start[i] >= self.stop[i]:
                    start[i] = self.start[i]
                    start[i-1] += self.step[i-1]
            if start[0] >= self.stop[0]:
                return
220