1 // Copyright 2014-2016 bluss and ndarray developers.
2 //
3 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6 // option. This file may not be copied, modified, or distributed
7 // except according to those terms.
8 use super::Dimension;
9 use crate::dimension::IntoDimension;
10 use crate::zip::Offset;
11 use crate::split_at::SplitAt;
12 use crate::Axis;
13 use crate::Layout;
14 use crate::NdProducer;
15 use crate::{ArrayBase, Data};
16
/// An iterator over the indexes of an array shape.
///
/// Iterator element type is `D::Pattern` (the pattern form of the index
/// type `D`).
#[derive(Clone)]
pub struct IndicesIter<D> {
    /// Shape whose indices are being enumerated.
    dim: D,
    /// Next index to yield; `None` once the iterator is exhausted.
    index: Option<D>,
}
25
26 /// Create an iterable of the array shape `shape`.
27 ///
28 /// *Note:* prefer higher order methods, arithmetic operations and
29 /// non-indexed iteration before using indices.
indices<E>(shape: E) -> Indices<E::Dim> where E: IntoDimension,30 pub fn indices<E>(shape: E) -> Indices<E::Dim>
31 where
32 E: IntoDimension,
33 {
34 let dim = shape.into_dimension();
35 Indices {
36 start: E::Dim::zeros(dim.ndim()),
37 dim,
38 }
39 }
40
41 /// Return an iterable of the indices of the passed-in array.
42 ///
43 /// *Note:* prefer higher order methods, arithmetic operations and
44 /// non-indexed iteration before using indices.
indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D> where S: Data, D: Dimension,45 pub fn indices_of<S, D>(array: &ArrayBase<S, D>) -> Indices<D>
46 where
47 S: Data,
48 D: Dimension,
49 {
50 indices(array.dim())
51 }
52
53 impl<D> Iterator for IndicesIter<D>
54 where
55 D: Dimension,
56 {
57 type Item = D::Pattern;
58 #[inline]
next(&mut self) -> Option<Self::Item>59 fn next(&mut self) -> Option<Self::Item> {
60 let index = match self.index {
61 None => return None,
62 Some(ref ix) => ix.clone(),
63 };
64 self.index = self.dim.next_for(index.clone());
65 Some(index.into_pattern())
66 }
67
size_hint(&self) -> (usize, Option<usize>)68 fn size_hint(&self) -> (usize, Option<usize>) {
69 let l = match self.index {
70 None => 0,
71 Some(ref ix) => {
72 let gone = self
73 .dim
74 .default_strides()
75 .slice()
76 .iter()
77 .zip(ix.slice().iter())
78 .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
79 self.dim.size() - gone
80 }
81 };
82 (l, Some(l))
83 }
84
fold<B, F>(self, init: B, mut f: F) -> B where F: FnMut(B, D::Pattern) -> B,85 fn fold<B, F>(self, init: B, mut f: F) -> B
86 where
87 F: FnMut(B, D::Pattern) -> B,
88 {
89 let IndicesIter { mut index, dim } = self;
90 let ndim = dim.ndim();
91 if ndim == 0 {
92 return match index {
93 Some(ix) => f(init, ix.into_pattern()),
94 None => init,
95 };
96 }
97 let inner_axis = ndim - 1;
98 let inner_len = dim[inner_axis];
99 let mut acc = init;
100 while let Some(mut ix) = index {
101 // unroll innermost axis
102 for i in ix[inner_axis]..inner_len {
103 ix[inner_axis] = i;
104 acc = f(acc, ix.clone().into_pattern());
105 }
106 index = dim.next_for(ix);
107 }
108 acc
109 }
110 }
111
112 impl<D> ExactSizeIterator for IndicesIter<D> where D: Dimension {}
113
114 impl<D> IntoIterator for Indices<D>
115 where
116 D: Dimension,
117 {
118 type Item = D::Pattern;
119 type IntoIter = IndicesIter<D>;
into_iter(self) -> Self::IntoIter120 fn into_iter(self) -> Self::IntoIter {
121 let sz = self.dim.size();
122 let index = if sz != 0 { Some(self.start) } else { None };
123 IndicesIter {
124 index,
125 dim: self.dim,
126 }
127 }
128 }
129
130 /// Indices producer and iterable.
131 ///
132 /// `Indices` is an `NdProducer` that produces the indices of an array shape.
133 #[derive(Copy, Clone, Debug)]
134 pub struct Indices<D>
135 where
136 D: Dimension,
137 {
138 start: D,
139 dim: D,
140 }
141
/// Stand-in for a pointer when producing indices: instead of an address it
/// simply carries an index value.
#[derive(Copy, Clone, Debug)]
pub struct IndexPtr<D> {
    /// The index this "pointer" currently refers to.
    index: D,
}
146
147 impl<D> Offset for IndexPtr<D>
148 where
149 D: Dimension + Copy,
150 {
151 // stride: The axis to increment
152 type Stride = usize;
153
stride_offset(mut self, stride: Self::Stride, index: usize) -> Self154 unsafe fn stride_offset(mut self, stride: Self::Stride, index: usize) -> Self {
155 self.index[stride] += index;
156 self
157 }
158 private_impl! {}
159 }
160
161 // How the NdProducer for Indices works.
162 //
163 // NdProducer allows for raw pointers (Ptr), strides (Stride) and the produced
164 // item (Item).
165 //
166 // Instead of Ptr, there is `IndexPtr<D>` which is an index value, like [0, 0, 0]
167 // for the three dimensional case.
168 //
169 // The stride is simply which axis is currently being incremented. The stride for axis 1, is 1.
170 //
171 // .stride_offset(stride, index) simply computes the new index along that axis, for example:
172 // [0, 0, 0].stride_offset(1, 10) => [0, 10, 0] axis 1 is incremented by 10.
173 //
174 // .as_ref() converts the Ptr value to an Item. For example [0, 10, 0] => (0, 10, 0)
175 impl<D: Dimension + Copy> NdProducer for Indices<D> {
176 type Item = D::Pattern;
177 type Dim = D;
178 type Ptr = IndexPtr<D>;
179 type Stride = usize;
180
181 private_impl! {}
182
183 #[doc(hidden)]
raw_dim(&self) -> Self::Dim184 fn raw_dim(&self) -> Self::Dim {
185 self.dim
186 }
187
188 #[doc(hidden)]
equal_dim(&self, dim: &Self::Dim) -> bool189 fn equal_dim(&self, dim: &Self::Dim) -> bool {
190 self.dim.equal(dim)
191 }
192
193 #[doc(hidden)]
as_ptr(&self) -> Self::Ptr194 fn as_ptr(&self) -> Self::Ptr {
195 IndexPtr { index: self.start }
196 }
197
198 #[doc(hidden)]
layout(&self) -> Layout199 fn layout(&self) -> Layout {
200 if self.dim.ndim() <= 1 {
201 Layout::one_dimensional()
202 } else {
203 Layout::none()
204 }
205 }
206
207 #[doc(hidden)]
as_ref(&self, ptr: Self::Ptr) -> Self::Item208 unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
209 ptr.index.into_pattern()
210 }
211
212 #[doc(hidden)]
uget_ptr(&self, i: &Self::Dim) -> Self::Ptr213 unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
214 let mut index = *i;
215 index += &self.start;
216 IndexPtr { index }
217 }
218
219 #[doc(hidden)]
stride_of(&self, axis: Axis) -> Self::Stride220 fn stride_of(&self, axis: Axis) -> Self::Stride {
221 axis.index()
222 }
223
224 #[inline(always)]
contiguous_stride(&self) -> Self::Stride225 fn contiguous_stride(&self) -> Self::Stride {
226 0
227 }
228
229 #[doc(hidden)]
split_at(self, axis: Axis, index: usize) -> (Self, Self)230 fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
231 let start_a = self.start;
232 let mut start_b = start_a;
233 let (a, b) = self.dim.split_at(axis, index);
234 start_b[axis.index()] += index;
235 (
236 Indices {
237 start: start_a,
238 dim: a,
239 },
240 Indices {
241 start: start_b,
242 dim: b,
243 },
244 )
245 }
246 }
247
/// An iterator over the indexes of an array shape, stepping with the
/// dimension's `next_for_f` (presumably Fortran / column-major order, going
/// by the helper's name).
///
/// Iterator element type is `D::Pattern` (the pattern form of the index
/// type `D`).
#[derive(Clone)]
pub struct IndicesIterF<D> {
    /// Shape whose indices are being enumerated.
    dim: D,
    /// Next index to yield; only meaningful while `has_remaining` is true.
    index: D,
    /// Whether `index` still refers to an element that has not been yielded.
    has_remaining: bool,
}
257
indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim> where E: IntoDimension,258 pub fn indices_iter_f<E>(shape: E) -> IndicesIterF<E::Dim>
259 where
260 E: IntoDimension,
261 {
262 let dim = shape.into_dimension();
263 let zero = E::Dim::zeros(dim.ndim());
264 IndicesIterF {
265 has_remaining: dim.size_checked() != Some(0),
266 index: zero,
267 dim,
268 }
269 }
270
271 impl<D> Iterator for IndicesIterF<D>
272 where
273 D: Dimension,
274 {
275 type Item = D::Pattern;
276 #[inline]
next(&mut self) -> Option<Self::Item>277 fn next(&mut self) -> Option<Self::Item> {
278 if !self.has_remaining {
279 None
280 } else {
281 let elt = self.index.clone().into_pattern();
282 self.has_remaining = self.dim.next_for_f(&mut self.index);
283 Some(elt)
284 }
285 }
286
size_hint(&self) -> (usize, Option<usize>)287 fn size_hint(&self) -> (usize, Option<usize>) {
288 if !self.has_remaining {
289 return (0, Some(0));
290 }
291 let gone = self
292 .dim
293 .fortran_strides()
294 .slice()
295 .iter()
296 .zip(self.index.slice().iter())
297 .fold(0, |s, (&a, &b)| s + a as usize * b as usize);
298 let l = self.dim.size() - gone;
299 (l, Some(l))
300 }
301 }
302
303 impl<D> ExactSizeIterator for IndicesIterF<D> where D: Dimension {}
304
#[cfg(test)]
mod tests {
    use super::indices;
    use super::indices_iter_f;

    /// `len()` of the C-order iterator must drop by exactly one per `next()`.
    #[test]
    fn test_indices_iter_c_size_hint() {
        let dim = (3, 4);
        let mut it = indices(dim).into_iter();
        let mut remaining = dim.0 * dim.1;
        assert_eq!(it.len(), remaining);
        while it.next().is_some() {
            remaining -= 1;
            assert_eq!(it.len(), remaining);
        }
        assert_eq!(remaining, 0);
    }

    /// The specialised `fold` must visit the same indices, in the same order,
    /// as repeated `next()` calls, including after partial consumption.
    #[test]
    fn test_indices_iter_c_fold() {
        macro_rules! run_test {
            ($dim:expr) => {
                for num_consume in 0..3 {
                    let mut it = indices($dim).into_iter();
                    // Advance a little first so `fold` starts mid-iteration.
                    for _ in 0..num_consume {
                        it.next();
                    }
                    let expected_count = it.len();
                    let folded = it.clone();
                    let visited = folded.fold(0, |count, ix| {
                        assert_eq!(ix, it.next().unwrap());
                        count + 1
                    });
                    assert_eq!(visited, expected_count);
                    assert!(it.next().is_none());
                }
            };
        }
        run_test!(());
        run_test!((2,));
        run_test!((2, 3));
        run_test!((2, 0, 3));
        run_test!((2, 3, 4));
        run_test!((2, 3, 4, 2));
    }

    /// `len()` of the F-order iterator must drop by exactly one per `next()`.
    #[test]
    fn test_indices_iter_f_size_hint() {
        let dim = (3, 4);
        let mut it = indices_iter_f(dim);
        let mut remaining = dim.0 * dim.1;
        assert_eq!(it.len(), remaining);
        while it.next().is_some() {
            remaining -= 1;
            assert_eq!(it.len(), remaining);
        }
        assert_eq!(remaining, 0);
    }
}
364