1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_MAP_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_MAP_H 12 13 namespace Eigen { 14 15 // FIXME use proper doxygen documentation (e.g. \tparam MakePointer_) 16 17 /** \class TensorMap 18 * \ingroup CXX11_Tensor_Module 19 * 20 * \brief A tensor expression mapping an existing array of data. 21 * 22 */ 23 /// `template <class> class MakePointer_` is added to convert the host pointer to the device pointer. 24 /// It is added due to the fact that for our device compiler `T*` is not allowed. 25 /// If we wanted to use the same Evaluator functions we have to convert that type to our pointer `T`. 26 /// This is done through our `MakePointer_` class. By default the Type in the `MakePointer_<T>` is `T*` . 27 /// Therefore, by adding the default value, we managed to convert the type and it does not break any 28 /// existing code as its default value is `T*`. 29 template<typename PlainObjectType, int Options_, template <class> class MakePointer_> class TensorMap : public TensorBase<TensorMap<PlainObjectType, Options_, MakePointer_> > 30 { 31 public: 32 typedef TensorMap<PlainObjectType, Options_, MakePointer_> Self; 33 typedef typename PlainObjectType::Base Base; 34 typedef typename Eigen::internal::nested<Self>::type Nested; 35 typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind; 36 typedef typename internal::traits<PlainObjectType>::Index Index; 37 typedef typename internal::traits<PlainObjectType>::Scalar Scalar; 38 typedef typename NumTraits<Scalar>::Real RealScalar; 39 typedef typename Base::CoeffReturnType CoeffReturnType; 40 41 /* typedef typename internal::conditional< 42 bool(internal::is_lvalue<PlainObjectType>::value), 43 Scalar *, 44 const Scalar *>::type 45 PointerType;*/ 46 typedef typename MakePointer_<Scalar>::Type PointerType; 47 typedef PointerType PointerArgType; 48 49 static const int Options = Options_; 50 51 static const Index NumIndices = PlainObjectType::NumIndices; 52 typedef typename PlainObjectType::Dimensions Dimensions; 53 54 enum { 55 IsAligned = ((int(Options_)&Aligned)==Aligned), 56 Layout = PlainObjectType::Layout, 57 CoordAccess = true, 58 RawAccess = true 59 }; 60 61 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr)62 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr) : m_data(dataPtr), m_dimensions() { 63 // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. 64 EIGEN_STATIC_ASSERT((0 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE) 65 } 66 67 #if EIGEN_HAS_VARIADIC_TEMPLATES 68 template<typename... IndexTypes> EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index firstDimension,IndexTypes...otherDimensions)69 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension, IndexTypes... otherDimensions) : m_data(dataPtr), m_dimensions(firstDimension, otherDimensions...) { 70 // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. 71 EIGEN_STATIC_ASSERT((sizeof...(otherDimensions) + 1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE) 72 } 73 #else 74 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index firstDimension)75 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index firstDimension) : m_data(dataPtr), m_dimensions(firstDimension) { 76 // The number of dimensions used to construct a tensor must be equal to the rank of the tensor. 77 EIGEN_STATIC_ASSERT((1 == NumIndices || NumIndices == Dynamic), YOU_MADE_A_PROGRAMMING_MISTAKE) 78 } 79 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2)80 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2) : m_data(dataPtr), m_dimensions(dim1, dim2) { 81 EIGEN_STATIC_ASSERT(2 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 82 } 83 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3)84 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3) { 85 EIGEN_STATIC_ASSERT(3 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 86 } 87 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4)88 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4) { 89 EIGEN_STATIC_ASSERT(4 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 90 } 91 EIGEN_DEVICE_FUNC TensorMap(PointerArgType dataPtr,Index dim1,Index dim2,Index dim3,Index dim4,Index dim5)92 EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) : m_data(dataPtr), m_dimensions(dim1, dim2, dim3, dim4, dim5) { 93 EIGEN_STATIC_ASSERT(5 == NumIndices || NumIndices == Dynamic, YOU_MADE_A_PROGRAMMING_MISTAKE) 94 } 95 #endif 96 TensorMap(PointerArgType dataPtr,const array<Index,NumIndices> & dimensions)97 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions) 98 : m_data(dataPtr), m_dimensions(dimensions) 99 { } 100 101 template <typename Dimensions> TensorMap(PointerArgType dataPtr,const Dimensions & dimensions)102 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions) 103 : m_data(dataPtr), m_dimensions(dimensions) 104 { } 105 TensorMap(PlainObjectType & tensor)106 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor) 107 : m_data(tensor.data()), m_dimensions(tensor.dimensions()) 108 { } 109 110 EIGEN_DEVICE_FUNC rank()111 EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); } 112 EIGEN_DEVICE_FUNC dimension(Index n)113 EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_dimensions[n]; } 114 EIGEN_DEVICE_FUNC dimensions()115 EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } 116 EIGEN_DEVICE_FUNC size()117 EIGEN_STRONG_INLINE Index size() const { return m_dimensions.TotalSize(); } 118 EIGEN_DEVICE_FUNC data()119 EIGEN_STRONG_INLINE PointerType data() { return m_data; } 120 EIGEN_DEVICE_FUNC data()121 EIGEN_STRONG_INLINE const PointerType data() const { return m_data; } 122 123 EIGEN_DEVICE_FUNC operator()124 EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const 125 { 126 // eigen_assert(checkIndexRange(indices)); 127 if (PlainObjectType::Options&RowMajor) { 128 const Index index = m_dimensions.IndexOfRowMajor(indices); 129 return m_data[index]; 130 } else { 131 const Index index = m_dimensions.IndexOfColMajor(indices); 132 return m_data[index]; 133 } 134 } 135 136 EIGEN_DEVICE_FUNC operator()137 EIGEN_STRONG_INLINE const Scalar& operator()() const 138 { 139 EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE) 140 return m_data[0]; 141 } 142 143 EIGEN_DEVICE_FUNC operator()144 EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const 145 { 146 eigen_internal_assert(index >= 0 && index < size()); 147 return m_data[index]; 148 } 149 150 #if EIGEN_HAS_VARIADIC_TEMPLATES 151 template<typename... IndexTypes> EIGEN_DEVICE_FUNC operator()152 EIGEN_STRONG_INLINE const Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) const 153 { 154 EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) 155 if (PlainObjectType::Options&RowMajor) { 156 const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}}); 157 return m_data[index]; 158 } else { 159 const Index index = m_dimensions.IndexOfColMajor(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}}); 160 return m_data[index]; 161 } 162 } 163 #else 164 EIGEN_DEVICE_FUNC operator()165 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const 166 { 167 if (PlainObjectType::Options&RowMajor) { 168 const Index index = i1 + i0 * m_dimensions[1]; 169 return m_data[index]; 170 } else { 171 const Index index = i0 + i1 * m_dimensions[0]; 172 return m_data[index]; 173 } 174 } 175 EIGEN_DEVICE_FUNC operator()176 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const 177 { 178 if (PlainObjectType::Options&RowMajor) { 179 const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0); 180 return m_data[index]; 181 } else { 182 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2); 183 return m_data[index]; 184 } 185 } 186 EIGEN_DEVICE_FUNC operator()187 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3) const 188 { 189 if (PlainObjectType::Options&RowMajor) { 190 const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)); 191 return m_data[index]; 192 } else { 193 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3)); 194 return m_data[index]; 195 } 196 } 197 EIGEN_DEVICE_FUNC operator()198 EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const 199 { 200 if (PlainObjectType::Options&RowMajor) { 201 const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0))); 202 return m_data[index]; 203 } else { 204 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4))); 205 return m_data[index]; 206 } 207 } 208 #endif 209 210 EIGEN_DEVICE_FUNC operator()211 EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices) 212 { 213 // eigen_assert(checkIndexRange(indices)); 214 if (PlainObjectType::Options&RowMajor) { 215 const Index index = m_dimensions.IndexOfRowMajor(indices); 216 return m_data[index]; 217 } else { 218 const Index index = m_dimensions.IndexOfColMajor(indices); 219 return m_data[index]; 220 } 221 } 222 223 EIGEN_DEVICE_FUNC operator()224 EIGEN_STRONG_INLINE Scalar& operator()() 225 { 226 EIGEN_STATIC_ASSERT(NumIndices == 0, YOU_MADE_A_PROGRAMMING_MISTAKE) 227 return m_data[0]; 228 } 229 230 EIGEN_DEVICE_FUNC operator()231 EIGEN_STRONG_INLINE Scalar& operator()(Index index) 232 { 233 eigen_internal_assert(index >= 0 && index < size()); 234 return m_data[index]; 235 } 236 237 #if EIGEN_HAS_VARIADIC_TEMPLATES 238 template<typename... IndexTypes> EIGEN_DEVICE_FUNC operator()239 EIGEN_STRONG_INLINE Scalar& operator()(Index firstIndex, Index secondIndex, IndexTypes... otherIndices) 240 { 241 static_assert(sizeof...(otherIndices) + 2 == NumIndices || NumIndices == Dynamic, "Number of indices used to access a tensor coefficient must be equal to the rank of the tensor."); 242 const std::size_t NumDims = sizeof...(otherIndices) + 2; 243 if (PlainObjectType::Options&RowMajor) { 244 const Index index = m_dimensions.IndexOfRowMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}}); 245 return m_data[index]; 246 } else { 247 const Index index = m_dimensions.IndexOfColMajor(array<Index, NumDims>{{firstIndex, secondIndex, otherIndices...}}); 248 return m_data[index]; 249 } 250 } 251 #else 252 EIGEN_DEVICE_FUNC operator()253 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1) 254 { 255 if (PlainObjectType::Options&RowMajor) { 256 const Index index = i1 + i0 * m_dimensions[1]; 257 return m_data[index]; 258 } else { 259 const Index index = i0 + i1 * m_dimensions[0]; 260 return m_data[index]; 261 } 262 } 263 EIGEN_DEVICE_FUNC operator()264 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2) 265 { 266 if (PlainObjectType::Options&RowMajor) { 267 const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0); 268 return m_data[index]; 269 } else { 270 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2); 271 return m_data[index]; 272 } 273 } 274 EIGEN_DEVICE_FUNC operator()275 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3) 276 { 277 if (PlainObjectType::Options&RowMajor) { 278 const Index index = i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0)); 279 return m_data[index]; 280 } else { 281 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * i3)); 282 return m_data[index]; 283 } 284 } 285 EIGEN_DEVICE_FUNC operator()286 EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3, Index i4) 287 { 288 if (PlainObjectType::Options&RowMajor) { 289 const Index index = i4 + m_dimensions[4] * (i3 + m_dimensions[3] * (i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0))); 290 return m_data[index]; 291 } else { 292 const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * (i2 + m_dimensions[2] * (i3 + m_dimensions[3] * i4))); 293 return m_data[index]; 294 } 295 } 296 #endif 297 298 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Self& operator=(const Self& other) 299 { 300 typedef TensorAssignOp<Self, const Self> Assign; 301 Assign assign(*this, other); 302 internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice()); 303 return *this; 304 } 305 306 template<typename OtherDerived> 307 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 308 Self& operator=(const OtherDerived& other) 309 { 310 typedef TensorAssignOp<Self, const OtherDerived> Assign; 311 Assign assign(*this, other); 312 internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice()); 313 return *this; 314 } 315 316 private: 317 typename MakePointer_<Scalar>::Type m_data; 318 Dimensions m_dimensions; 319 }; 320 321 } // end namespace Eigen 322 323 #endif // EIGEN_CXX11_TENSOR_TENSOR_MAP_H 324