1 // Copyright 2014-2016 bluss and ndarray developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 use super::Dimension;
9 use crate::dimension::IntoDimension;
10 use crate::zip::Offset;
11 use crate::split_at::SplitAt;
12 use crate::Axis;
13 use crate::Layout;
14 use crate::NdProducer;
15 use crate::{ArrayBase, Data};
16 
/// An iterator over the indexes of an array shape.
///
/// Iterator element type is `D::Pattern`.
#[derive(Clone)]
pub struct IndicesIter<D> {
    // Shape whose indices are being enumerated.
    dim: D,
    // Next index to yield; `None` once iteration has finished (or when the
    // shape holds no elements at all).
    index: Option<D>,
}
25 
26 /// Create an iterable of the array shape `shape`.
27 ///
28 /// *Note:* prefer higher order methods, arithmetic operations and
29 /// non-indexed iteration before using indices.
indices<E>(shape: E) -> Indices<E::Dim> where E: IntoDimension,30 pub fn indices<E>(shape: E) -> Indices<E::Dim>
31 where
32     E: IntoDimension,
33 {
34     let dim = shape.into_dimension();
35     Indices {
36         start: E::Dim::zeros(dim.ndim()),
37         dim,
38     }
39 }
40 
41 /// Return an iterable of the indices of the passed-in array.
42 ///
43 /// *Note:* prefer higher order methods, arithmetic operations and
44 /// non-indexed iteration before using indices.
indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D> where S: Data, D: Dimension,45 pub fn indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D>
46 where
47     S: Data,
48     D: Dimension,
49 {
50     indices(array.dim())
51 }
52 
53 impl<D> Iterator for IndicesIter<D>
54 where
55     D: Dimension,
56 {
57     type Item = D::Pattern;
58     #[inline]
next(&mut self) -> Option<Self::Item>59     fn next(&mut self) -> Option<Self::Item> {
60         let index = match self.index {
61             None => return None,
62             Some(ref ix) => ix.clone(),
63         };
64         self.index = self.dim.next_for(index.clone());
65         Some(index.into_pattern())
66     }
67 
size_hint(&self) -> (usize, Option<usize>)68     fn size_hint(&self) -> (usize, Option<usize>) {
69         let l = match self.index {
70             None => 0,
71             Some(ref ix) => {
72                 let gone = self
73                     .dim
74                     .default_strides()
75                     .slice()
76                     .iter()
77                     .zip(ix.slice().iter())
78                     .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
79                 self.dim.size() - gone
80             }
81         };
82         (l, Some(l))
83     }
84 
fold<B, F>(self, init: B, mut f: F) -> B where F: FnMut(B, D::Pattern) -> B,85     fn fold<B, F>(self, init: B, mut f: F) -> B
86     where
87         F: FnMut(B, D::Pattern) -> B,
88     {
89         let IndicesIter { mut index, dim } = self;
90         let ndim = dim.ndim();
91         if ndim == 0 {
92             return match index {
93                 Some(ix) => f(init, ix.into_pattern()),
94                 None => init,
95             };
96         }
97         let inner_axis = ndim - 1;
98         let inner_len = dim[inner_axis];
99         let mut acc = init;
100         while let Some(mut ix) = index {
101             // unroll innermost axis
102             for i in ix[inner_axis]..inner_len {
103                 ix[inner_axis] = i;
104                 acc = f(acc, ix.clone().into_pattern());
105             }
106             index = dim.next_for(ix);
107         }
108         acc
109     }
110 }
111 
112 impl<D> ExactSizeIterator for IndicesIter<D> where D: Dimension {}
113 
114 impl<D> IntoIterator for Indices<D>
115 where
116     D: Dimension,
117 {
118     type Item = D::Pattern;
119     type IntoIter = IndicesIter<D>;
into_iter(self) -> Self::IntoIter120     fn into_iter(self) -> Self::IntoIter {
121         let sz = self.dim.size();
122         let index = if sz != 0 { Some(self.start) } else { None };
123         IndicesIter {
124             index,
125             dim: self.dim,
126         }
127     }
128 }
129 
130 /// Indices producer and iterable.
131 ///
132 /// `Indices` is an `NdProducer` that produces the indices of an array shape.
133 #[derive(Copy, Clone, Debug)]
134 pub struct Indices<D>
135 where
136     D: Dimension,
137 {
138     start: D,
139     dim: D,
140 }
141 
/// Index value used as the `Ptr` type when `Indices` acts as an `NdProducer`.
#[derive(Copy, Clone, Debug)]
pub struct IndexPtr<D> {
    // The index this "pointer" currently designates.
    index: D,
}
146 
147 impl<D> Offset for IndexPtr<D>
148 where
149     D: Dimension + Copy,
150 {
151     // stride: The axis to increment
152     type Stride = usize;
153 
stride_offset(mut self, stride: Self::Stride, index: usize) -> Self154     unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self {
155         self.index[stride] += index;
156         self
157     }
158     private_impl! {}
159 }
160 
161 // How the NdProducer for Indices works.
162 //
163 // NdProducer allows for raw pointers (Ptr), strides (Stride) and the produced
164 // item (Item).
165 //
166 // Instead of Ptr, there is `IndexPtr<D>` which is an index value, like [0, 0, 0]
167 // for the three dimensional case.
168 //
169 // The stride is simply which axis is currently being incremented. The stride for axis 1, is 1.
170 //
171 // .stride_offset(stride, index) simply computes the new index along that axis, for example:
172 // [0, 0, 0].stride_offset(1, 10) => [0, 10, 0]  axis 1 is incremented by 10.
173 //
174 // .as_ref() converts the Ptr value to an Item. For example [0, 10, 0] => (0, 10, 0)
175 impl<D: Dimension + Copy> NdProducer for Indices<D> {
176     type Item = D::Pattern;
177     type Dim = D;
178     type Ptr = IndexPtr<D>;
179     type Stride = usize;
180 
181     private_impl! {}
182 
183     #[doc(hidden)]
raw_dim(&self) -> Self::Dim184     fn raw_dim(&self) -> Self::Dim {
185         self.dim
186     }
187 
188     #[doc(hidden)]
equal_dim(&self, dim: &Self::Dim) -> bool189     fn equal_dim(&self, dim: &Self::Dim) -> bool {
190         self.dim.equal(dim)
191     }
192 
193     #[doc(hidden)]
as_ptr(&self) -> Self::Ptr194     fn as_ptr(&self) -> Self::Ptr {
195         IndexPtr { index: self.start }
196     }
197 
198     #[doc(hidden)]
layout(&self) -> Layout199     fn layout(&self) -> Layout {
200         if self.dim.ndim() <= 1 {
201             Layout::one_dimensional()
202         } else {
203             Layout::none()
204         }
205     }
206 
207     #[doc(hidden)]
as_ref(&self, ptr: Self::Ptr) -> Self::Item208     unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
209         ptr.index.into_pattern()
210     }
211 
212     #[doc(hidden)]
uget_ptr(&self, i: &Self::Dim) -> Self::Ptr213     unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
214         let mut index = *i;
215         index += &self.start;
216         IndexPtr { index }
217     }
218 
219     #[doc(hidden)]
stride_of(&self, axis: Axis) -> Self::Stride220     fn stride_of(&self, axis: Axis) -> Self::Stride {
221         axis.index()
222     }
223 
224     #[inline(always)]
contiguous_stride(&self) -> Self::Stride225     fn contiguous_stride(&self) -> Self::Stride {
226         0
227     }
228 
229     #[doc(hidden)]
split_at(self, axis: Axis, index: usize) -> (Self, Self)230     fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
231         let start_a = self.start;
232         let mut start_b = start_a;
233         let (a, b) = self.dim.split_at(axis, index);
234         start_b[axis.index()] += index;
235         (
236             Indices {
237                 start: start_a,
238                 dim: a,
239             },
240             Indices {
241                 start: start_b,
242                 dim: b,
243             },
244         )
245     }
246 }
247 
/// An iterator over the indexes of an array shape.
///
/// Iterator element type is `D::Pattern`. The index is advanced with
/// `Dimension::next_for_f` (presumably column-major order, judging by the
/// `_f` naming — confirm against `Dimension`).
#[derive(Clone)]
pub struct IndicesIterF<D> {
    // Shape whose indices are being enumerated.
    dim: D,
    // Index that the next call to `next` will yield.
    index: D,
    // `false` once every index has been yielded (or if the shape is empty).
    has_remaining: bool,
}
257 
indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim> where E: IntoDimension,258 pub fn indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim>
259 where
260     E: IntoDimension,
261 {
262     let dim = shape.into_dimension();
263     let zero = E::Dim::zeros(dim.ndim());
264     IndicesIterF {
265         has_remaining: dim.size_checked() != Some(0),
266         index: zero,
267         dim,
268     }
269 }
270 
271 impl<D> Iterator for IndicesIterF<D>
272 where
273     D: Dimension,
274 {
275     type Item = D::Pattern;
276     #[inline]
next(&mut self) -> Option<Self::Item>277     fn next(&mut self) -> Option<Self::Item> {
278         if !self.has_remaining {
279             None
280         } else {
281             let elt = self.index.clone().into_pattern();
282             self.has_remaining = self.dim.next_for_f(&mut self.index);
283             Some(elt)
284         }
285     }
286 
size_hint(&self) -> (usize, Option<usize>)287     fn size_hint(&self) -> (usize, Option<usize>) {
288         if !self.has_remaining {
289             return (0, Some(0));
290         }
291         let gone = self
292             .dim
293             .fortran_strides()
294             .slice()
295             .iter()
296             .zip(self.index.slice().iter())
297             .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
298         let l = self.dim.size() - gone;
299         (l, Some(l))
300     }
301 }
302 
303 impl<D> ExactSizeIterator for IndicesIterF<D> where D: Dimension {}
304 
#[cfg(test)]
mod tests {
    use super::{indices, indices_iter_f};

    /// `len()` of `IndicesIter` must shrink by exactly one per `next()` call
    /// and reach zero when the iterator is exhausted.
    #[test]
    fn test_indices_iter_c_size_hint() {
        let (rows, cols) = (3, 4);
        let mut it = indices((rows, cols)).into_iter();
        let mut remaining = rows * cols;
        assert_eq!(it.len(), remaining);
        while it.next().is_some() {
            remaining -= 1;
            assert_eq!(it.len(), remaining);
        }
        assert_eq!(remaining, 0);
    }

    /// The specialised `fold` must visit the same indices, in the same order,
    /// as repeated `next()` calls, including after partial consumption.
    #[test]
    fn test_indices_iter_c_fold() {
        macro_rules! run_test {
            ($dim:expr) => {
                for num_consume in 0..3 {
                    let mut reference = indices($dim).into_iter();
                    // Consume a few leading items so `fold` starts mid-way.
                    for _ in 0..num_consume {
                        reference.next();
                    }
                    let folded = reference.clone();
                    let expected = reference.len();
                    let visited = folded.fold(0, |count, ix| {
                        assert_eq!(ix, reference.next().unwrap());
                        count + 1
                    });
                    assert_eq!(visited, expected);
                    assert!(reference.next().is_none());
                }
            };
        }
        run_test!(());
        run_test!((2,));
        run_test!((2, 3));
        run_test!((2, 0, 3));
        run_test!((2, 3, 4));
        run_test!((2, 3, 4, 2));
    }

    /// `len()` of `IndicesIterF` must shrink by exactly one per `next()` call
    /// and reach zero when the iterator is exhausted.
    #[test]
    fn test_indices_iter_f_size_hint() {
        let (rows, cols) = (3, 4);
        let mut it = indices_iter_f((rows, cols));
        let mut remaining = rows * cols;
        assert_eq!(it.len(), remaining);
        while it.next().is_some() {
            remaining -= 1;
            assert_eq!(it.len(), remaining);
        }
        assert_eq!(remaining, 0);
    }
}
364