1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_MAP_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_MAP_H
12 
13 namespace Eigen {
14 
15 // FIXME use proper doxygen documentation (e.g. \tparam MakePointer_)
16 
17 /** \class TensorMap
18   * \ingroup CXX11_Tensor_Module
19   *
20   * \brief A tensor expression mapping an existing array of data.
21   *
22   */
/// The `template <class> class MakePointer_` parameter converts the host pointer type into a device pointer type.
/// It exists because at least one supported device compiler does not allow raw `T*` pointers.
/// To reuse the same evaluator functions on such a device, the pointer type must be replaced by the device's own
/// pointer type; `MakePointer_` performs that conversion. By default, `MakePointer_<T>::Type` is `T*`, so adding
/// this parameter with that default does not break any existing code.
29 template<typename PlainObjectType, int Options_, template <class> class MakePointer_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_, MakePointer_> >
30 {
31   public:
32     typedef TensorMap<PlainObjectType, Options_, MakePointer_> Self;
33     typedef typename PlainObjectType::Base Base;
34     typedef typename Eigen::internal::nested<Self>::type Nested;
35     typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
36     typedef typename internal::traits<PlainObjectType>::Index Index;
37     typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
38     typedef typename NumTraits<Scalar>::Real RealScalar;
39     typedef typename Base::CoeffReturnType CoeffReturnType;
40 
41   /*    typedef typename internal::conditional<
42                          bool(internal::is_lvalue<PlainObjectType>::value),
43                          Scalar *,
44                          const Scalar *>::type
45                      PointerType;*/
46     typedef typename MakePointer_<Scalar>::Type PointerType;
47     typedef PointerType PointerArgType;
48 
49     static const int Options = Options_;
50 
51     static const Index NumIndices = PlainObjectType::NumIndices;
52     typedef typename PlainObjectType::Dimensions Dimensions;
53 
54     enum {
55       IsAligned = ((int(Options_)&Aligned)==Aligned),
56       Layout = PlainObjectType::Layout,
57       CoordAccess = true,
58       RawAccess = true
59     };
60 
61     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr)62     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr), m_dimensions() {
63       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
64       EIGEN_STATIC_ASSERT((0 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
65     }
66 
67 #if EIGEN_HAS_VARIADIC_TEMPLATES
68     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index firstDimension,IndexTypes...otherDimensions)69     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) {
70       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
71       EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
72     }
73 #else
74     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index firstDimension)75     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) {
76       // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
77       EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE)
78     }
79     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2)80     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) {
81       EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
82     }
83     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3)84     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) {
85       EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
86     }
87     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4)88     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) {
89       EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
90     }
91     EIGEN_DEVICE_FUNC
TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4,Index dim5)92     EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) {
93       EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE)
94     }
95 #endif
96 
TensorMap(PointerArgType dataPtr,const array<Index,NumIndices> & dimensions)97    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
98       : m_data(dataPtr), m_dimensions(dimensions)
99     { }
100 
101     template <typename Dimensions>
TensorMap(PointerArgType dataPtr,const Dimensions & dimensions)102     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
103       : m_data(dataPtr), m_dimensions(dimensions)
104     { }
105 
TensorMap(PlainObjectType & tensor)106     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor)
107       : m_data(tensor.data()), m_dimensions(tensor.dimensions())
108     { }
109 
110     EIGEN_DEVICE_FUNC
rank()111     EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
112     EIGEN_DEVICE_FUNC
dimension(Index n)113     EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; }
114     EIGEN_DEVICE_FUNC
dimensions()115     EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
116     EIGEN_DEVICE_FUNC
size()117     EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); }
118     EIGEN_DEVICE_FUNC
data()119     EIGEN_STRONG_INLINE PointerType data() { return m_data; }
120     EIGEN_DEVICE_FUNC
data()121     EIGEN_STRONG_INLINE const PointerType data() const { return m_data; }
122 
123     EIGEN_DEVICE_FUNC
operator()124     EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
125     {
126       //      eigen_assert(checkIndexRange(indices));
127       if (PlainObjectType::Options&RowMajor) {
128         const Index index = m_dimensions.IndexOfRowMajor(indices);
129         return m_data[index];
130       } else {
131         const Index index = m_dimensions.IndexOfColMajor(indices);
132         return m_data[index];
133       }
134     }
135 
136     EIGEN_DEVICE_FUNC
operator()137     EIGEN_STRONG_INLINE const Scalar& operator()() const
138     {
139       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
140       return m_data[0];
141     }
142 
143     EIGEN_DEVICE_FUNC
operator()144     EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const
145     {
146       eigen_internal_assert(index >= 0 && index < size());
147       return m_data[index];
148     }
149 
150 #if EIGEN_HAS_VARIADIC_TEMPLATES
151     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
operator()152     EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) const
153     {
154       EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
155       if (PlainObjectType::Options&RowMajor) {
156         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
157         return m_data[index];
158       } else {
159         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
160         return m_data[index];
161       }
162     }
163 #else
164     EIGEN_DEVICE_FUNC
operator()165     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const
166     {
167       if (PlainObjectType::Options&RowMajor) {
168         const Index index = i1 + i0 * m_dimensions[1];
169         return m_data[index];
170       } else {
171         const Index index = i0 + i1 * m_dimensions[0];
172         return m_data[index];
173       }
174     }
175     EIGEN_DEVICE_FUNC
operator()176     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const
177     {
178       if (PlainObjectType::Options&RowMajor) {
179          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
180          return m_data[index];
181       } else {
182          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
183         return m_data[index];
184       }
185     }
186     EIGEN_DEVICE_FUNC
operator()187     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3) const
188     {
189       if (PlainObjectType::Options&RowMajor) {
190         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
191         return m_data[index];
192       } else {
193         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
194         return m_data[index];
195       }
196     }
197     EIGEN_DEVICE_FUNC
operator()198     EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
199     {
200       if (PlainObjectType::Options&RowMajor) {
201         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
202         return m_data[index];
203       } else {
204         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
205         return m_data[index];
206       }
207     }
208 #endif
209 
210     EIGEN_DEVICE_FUNC
operator()211     EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
212     {
213       //      eigen_assert(checkIndexRange(indices));
214       if (PlainObjectType::Options&RowMajor) {
215         const Index index = m_dimensions.IndexOfRowMajor(indices);
216         return m_data[index];
217       } else {
218         const Index index = m_dimensions.IndexOfColMajor(indices);
219         return m_data[index];
220       }
221     }
222 
223     EIGEN_DEVICE_FUNC
operator()224     EIGEN_STRONG_INLINE Scalar& operator()()
225     {
226       EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE)
227       return m_data[0];
228     }
229 
230     EIGEN_DEVICE_FUNC
operator()231     EIGEN_STRONG_INLINE Scalar& operator()(Index index)
232     {
233       eigen_internal_assert(index >= 0 && index < size());
234       return m_data[index];
235     }
236 
237 #if EIGEN_HAS_VARIADIC_TEMPLATES
238     template<typename... IndexTypes> EIGEN_DEVICE_FUNC
operator()239     EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices)
240     {
241       static_assert(sizeof...(otherIndices) + 2 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor.");
242       const std::size_t NumDims = sizeof...(otherIndices) + 2;
243       if (PlainObjectType::Options&RowMajor) {
244         const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
245         return m_data[index];
246       } else {
247         const Index index = m_dimensions.IndexOfColMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}});
248         return m_data[index];
249       }
250     }
251 #else
252     EIGEN_DEVICE_FUNC
operator()253     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1)
254     {
255        if (PlainObjectType::Options&RowMajor) {
256          const Index index = i1 + i0 * m_dimensions[1];
257         return m_data[index];
258       } else {
259         const Index index = i0 + i1 * m_dimensions[0];
260         return m_data[index];
261       }
262     }
263     EIGEN_DEVICE_FUNC
operator()264     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2)
265     {
266        if (PlainObjectType::Options&RowMajor) {
267          const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
268         return m_data[index];
269       } else {
270          const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
271         return m_data[index];
272       }
273     }
274     EIGEN_DEVICE_FUNC
operator()275     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
276     {
277       if (PlainObjectType::Options&RowMajor) {
278         const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0));
279         return m_data[index];
280       } else {
281         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3));
282         return m_data[index];
283       }
284     }
285     EIGEN_DEVICE_FUNC
operator()286     EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4)
287     {
288       if (PlainObjectType::Options&RowMajor) {
289         const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)));
290         return m_data[index];
291       } else {
292         const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4)));
293         return m_data[index];
294       }
295     }
296 #endif
297 
298     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Self& operator=(const Self& other)
299     {
300       typedef TensorAssignOp<Self, const Self> Assign;
301       Assign assign(*this, other);
302       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
303       return *this;
304     }
305 
306     template<typename OtherDerived>
307     EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
308     Self& operator=(const OtherDerived& other)
309     {
310       typedef TensorAssignOp<Self, const OtherDerived> Assign;
311       Assign assign(*this, other);
312       internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
313       return *this;
314     }
315 
316   private:
317     typename MakePointer_<Scalar>::Type m_data;
318     Dimensions m_dimensions;
319 };
320 
321 } // end namespace Eigen
322 
323 #endif // EIGEN_CXX11_TENSOR_TENSOR_MAP_H
324