1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H 12 13 namespace Eigen { 14 15 namespace internal { 16 17 enum { 18 Rhs = 0, 19 Lhs = 1 20 }; 21 22 /* 23 * Implementation of the Eigen blas_data_mapper class for tensors. 24 */ 25 26 template <typename Tensor, bool HasRawAccess> struct CoeffLoader { 27 enum { 28 DirectOffsets = false 29 }; 30 CoeffLoaderCoeffLoader31 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_tensor(tensor) { } 32 offsetBufferCoeffLoader33 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index) { 34 eigen_assert(false && "unsupported"); 35 } 36 coeffCoeffLoader37 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); } 38 39 template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE packetCoeffLoader40 typename Tensor::PacketReturnType packet(typename Tensor::Index index) const 41 { 42 return m_tensor.template packet<LoadMode>(index); 43 } 44 45 46 private: 47 const Tensor m_tensor; 48 }; 49 50 template <typename Tensor> struct CoeffLoader<Tensor, true> { 51 enum { 52 DirectOffsets = true 53 }; 54 55 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE CoeffLoader(const Tensor& tensor) : m_data(tensor.data()) {} 56 57 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) { 58 m_data += offset; 59 } 60 61 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); } 62 63 template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 64 typename Tensor::PacketReturnType packet(typename Tensor::Index index) const 65 { 66 return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index); 67 } 68 private: 69 typedef typename Tensor::Scalar Scalar; 70 const Scalar* m_data; 71 }; 72 73 template<typename Scalar, typename Index, int side, 74 typename Tensor, 75 typename nocontract_t, typename contract_t, 76 int packet_size, bool inner_dim_contiguous, int Alignment> 77 class SimpleTensorContractionMapper { 78 public: 79 EIGEN_DEVICE_FUNC 80 SimpleTensorContractionMapper(const Tensor& tensor, 81 const nocontract_t& nocontract_strides, 82 const nocontract_t& ij_strides, 83 const contract_t& contract_strides, 84 const contract_t& k_strides) : 85 m_tensor(tensor), 86 m_nocontract_strides(nocontract_strides), 87 m_ij_strides(ij_strides), 88 m_contract_strides(contract_strides), 89 m_k_strides(k_strides) { } 90 91 enum { 92 DirectOffsets = CoeffLoader<Tensor, Tensor::RawAccess>::DirectOffsets 93 }; 94 95 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void offsetBuffer(typename Tensor::Index offset) { 96 m_tensor.offsetBuffer(offset); 97 } 98 99 EIGEN_DEVICE_FUNC 100 EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { } 101 102 EIGEN_DEVICE_FUNC 103 EIGEN_STRONG_INLINE Scalar operator()(Index row) const { 104 // column major assumption 105 return operator()(row, 0); 106 } 107 108 EIGEN_DEVICE_FUNC 109 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col) const { 110 return m_tensor.coeff(computeIndex(row, col)); 111 } 112 113 EIGEN_DEVICE_FUNC 114 EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col) const { 115 const bool left = (side == Lhs); 116 EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963 117 Index nocontract_val = left ? row : col; 118 Index linidx = 0; 119 for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) { 120 const Index idx = nocontract_val / m_ij_strides[i]; 121 linidx += idx * m_nocontract_strides[i]; 122 nocontract_val -= idx * m_ij_strides[i]; 123 } 124 if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) { 125 if (side == Lhs && inner_dim_contiguous) { 126 eigen_assert(m_nocontract_strides[0] == 1); 127 linidx += nocontract_val; 128 } else { 129 linidx += nocontract_val * m_nocontract_strides[0]; 130 } 131 } 132 133 Index contract_val = left ? col : row; 134 if(array_size<contract_t>::value > 0) { 135 for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { 136 const Index idx = contract_val / m_k_strides[i]; 137 linidx += idx * m_contract_strides[i]; 138 contract_val -= idx * m_k_strides[i]; 139 } 140 141 if (side == Rhs && inner_dim_contiguous) { 142 eigen_assert(m_contract_strides[0] == 1); 143 linidx += contract_val; 144 } else { 145 linidx += contract_val * m_contract_strides[0]; 146 } 147 } 148 149 return linidx; 150 } 151 152 EIGEN_DEVICE_FUNC 153 EIGEN_STRONG_INLINE IndexPair<Index> computeIndexPair(Index row, Index col, const Index distance) const { 154 const bool left = (side == Lhs); 155 EIGEN_UNUSED_VARIABLE(left); // annoying bug in g++8.1: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=85963 156 Index nocontract_val[2] = {left ? row : col, left ? row + distance : col}; 157 Index linidx[2] = {0, 0}; 158 if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) { 159 for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) { 160 const Index idx0 = nocontract_val[0] / m_ij_strides[i]; 161 const Index idx1 = nocontract_val[1] / m_ij_strides[i]; 162 linidx[0] += idx0 * m_nocontract_strides[i]; 163 linidx[1] += idx1 * m_nocontract_strides[i]; 164 nocontract_val[0] -= idx0 * m_ij_strides[i]; 165 nocontract_val[1] -= idx1 * m_ij_strides[i]; 166 } 167 if (side == Lhs && inner_dim_contiguous) { 168 eigen_assert(m_nocontract_strides[0] == 1); 169 linidx[0] += nocontract_val[0]; 170 linidx[1] += nocontract_val[1]; 171 } else { 172 linidx[0] += nocontract_val[0] * m_nocontract_strides[0]; 173 linidx[1] += nocontract_val[1] * m_nocontract_strides[0]; 174 } 175 } 176 177 Index contract_val[2] = {left ? col : row, left ? col : row + distance}; 178 if (array_size<contract_t>::value> 0) { 179 for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) { 180 const Index idx0 = contract_val[0] / m_k_strides[i]; 181 const Index idx1 = contract_val[1] / m_k_strides[i]; 182 linidx[0] += idx0 * m_contract_strides[i]; 183 linidx[1] += idx1 * m_contract_strides[i]; 184 contract_val[0] -= idx0 * m_k_strides[i]; 185 contract_val[1] -= idx1 * m_k_strides[i]; 186 } 187 188 if (side == Rhs && inner_dim_contiguous) { 189 eigen_assert(m_contract_strides[0] == 1); 190 linidx[0] += contract_val[0]; 191 linidx[1] += contract_val[1]; 192 } else { 193 linidx[0] += contract_val[0] * m_contract_strides[0]; 194 linidx[1] += contract_val[1] * m_contract_strides[0]; 195 } 196 } 197 return IndexPair<Index>(linidx[0], linidx[1]); 198 } 199 200 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size) const { 201 // Only claim alignment when we can compute the actual stride (ie when we're 202 // dealing with the lhs with inner_dim_contiguous. This is because the 203 // matrix-vector product relies on the stride when dealing with aligned inputs. 204 return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size; 205 } 206 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const { 207 return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1; 208 } 209 210 protected: 211 CoeffLoader<Tensor, Tensor::RawAccess> m_tensor; 212 const nocontract_t m_nocontract_strides; 213 const nocontract_t m_ij_strides; 214 const contract_t m_contract_strides; 215 const contract_t m_k_strides; 216 }; 217 218 219 template<typename Scalar, typename Index, int side, 220 typename Tensor, 221 typename nocontract_t, typename contract_t, 222 int packet_size, bool inner_dim_contiguous, 223 bool inner_dim_reordered, int Alignment> 224 class BaseTensorContractionMapper : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> 225 { 226 public: 227 typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, Alignment> ParentMapper; 228 229 EIGEN_DEVICE_FUNC 230 BaseTensorContractionMapper(const Tensor& tensor, 231 const nocontract_t& nocontract_strides, 232 const nocontract_t& ij_strides, 233 const contract_t& contract_strides, 234 const contract_t& k_strides) : 235 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } 236 237 typedef typename Tensor::PacketReturnType Packet; 238 typedef typename unpacket_traits<Packet>::half HalfPacket; 239 240 template <int AlignmentType> 241 EIGEN_DEVICE_FUNC 242 EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { 243 // whole method makes column major assumption 244 245 // don't need to add offsets for now (because operator handles that) 246 // current code assumes packet size must be a multiple of 2 247 EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE); 248 249 if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) { 250 const Index index = this->computeIndex(i, j); 251 eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1); 252 return this->m_tensor.template packet<AlignmentType>(index); 253 } 254 255 const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1); 256 const Index first = indexPair.first; 257 const Index last = indexPair.second; 258 259 // We can always do optimized packet reads from left hand side right now, because 260 // the vertical matrix dimension on the left hand side is never contracting. 261 // On the right hand side we need to check if the contracting dimensions may have 262 // been shuffled first. 263 if (Tensor::PacketAccess && 264 (side == Lhs || internal::array_size<contract_t>::value <= 1 || !inner_dim_reordered) && 265 (last - first) == (packet_size - 1)) { 266 267 return this->m_tensor.template packet<AlignmentType>(first); 268 } 269 270 EIGEN_ALIGN_MAX Scalar data[packet_size]; 271 272 data[0] = this->m_tensor.coeff(first); 273 for (Index k = 1; k < packet_size - 1; k += 2) { 274 const IndexPair<Index> internal_pair = this->computeIndexPair(i + k, j, 1); 275 data[k] = this->m_tensor.coeff(internal_pair.first); 276 data[k + 1] = this->m_tensor.coeff(internal_pair.second); 277 } 278 data[packet_size - 1] = this->m_tensor.coeff(last); 279 280 return pload<Packet>(data); 281 } 282 283 template <int AlignmentType> 284 EIGEN_DEVICE_FUNC 285 EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j) const { 286 // whole method makes column major assumption 287 288 // don't need to add offsets for now (because operator handles that) 289 const Index half_packet_size = unpacket_traits<HalfPacket>::size; 290 if (half_packet_size == packet_size) { 291 return loadPacket<AlignmentType>(i, j); 292 } 293 EIGEN_ALIGN_MAX Scalar data[half_packet_size]; 294 for (Index k = 0; k < half_packet_size; k++) { 295 data[k] = operator()(i + k, j); 296 } 297 return pload<HalfPacket>(data); 298 } 299 }; 300 301 302 template<typename Scalar, typename Index, int side, 303 typename Tensor, 304 typename nocontract_t, typename contract_t, 305 bool inner_dim_contiguous, 306 bool inner_dim_reordered, int Alignment> 307 class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> : public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> 308 { 309 public: 310 typedef SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment> ParentMapper; 311 312 EIGEN_DEVICE_FUNC 313 BaseTensorContractionMapper(const Tensor& tensor, 314 const nocontract_t& nocontract_strides, 315 const nocontract_t& ij_strides, 316 const contract_t& contract_strides, 317 const contract_t& k_strides) : 318 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } 319 320 typedef typename Tensor::PacketReturnType Packet; 321 template <int> EIGEN_DEVICE_FUNC 322 EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { 323 EIGEN_ALIGN_MAX Scalar data[1]; 324 data[0] = this->m_tensor.coeff(this->computeIndex(i, j)); 325 return pload<typename Tensor::PacketReturnType>(data); 326 } 327 template <int> EIGEN_DEVICE_FUNC 328 EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j) const { 329 return loadPacket(i, j); 330 } 331 }; 332 333 334 template<typename Scalar, typename Index, int side, 335 typename Tensor, 336 typename nocontract_t, typename contract_t, 337 int packet_size, 338 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> 339 class TensorContractionSubMapper { 340 public: 341 typedef typename Tensor::PacketReturnType Packet; 342 typedef typename unpacket_traits<Packet>::half HalfPacket; 343 344 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> ParentMapper; 345 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Self; 346 typedef Self LinearMapper; 347 348 enum { 349 // We can use direct offsets iff the parent mapper supports then and we can compute the strides. 350 // TODO: we should also enable direct offsets for the Rhs case. 351 UseDirectOffsets = ParentMapper::DirectOffsets && (side == Lhs) && inner_dim_contiguous && (array_size<contract_t>::value > 0) 352 }; 353 354 EIGEN_DEVICE_FUNC TensorContractionSubMapper(const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) 355 : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) { 356 // Bake the offsets into the buffer used by the base mapper whenever possible. This avoids the need to recompute 357 // this offset every time we attempt to access a coefficient. 358 if (UseDirectOffsets) { 359 Index stride = m_base_mapper.stride(); 360 m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride); 361 } 362 } 363 364 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { 365 if (UseDirectOffsets) { 366 return m_base_mapper(i, 0); 367 } 368 return m_base_mapper(i + m_vert_offset, m_horiz_offset); 369 } 370 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j) const { 371 if (UseDirectOffsets) { 372 return m_base_mapper(i, j); 373 } 374 return m_base_mapper(i + m_vert_offset, j + m_horiz_offset); 375 } 376 377 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { 378 if (UseDirectOffsets) { 379 return m_base_mapper.template loadPacket<Alignment>(i, 0); 380 } 381 return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset); 382 } 383 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const { 384 if (UseDirectOffsets) { 385 return m_base_mapper.template loadPacket<Alignment>(i, j); 386 } 387 return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset); 388 } 389 390 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const { 391 if (UseDirectOffsets) { 392 return m_base_mapper.template loadHalfPacket<Alignment>(i, 0); 393 } 394 return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset); 395 } 396 397 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, Packet p) const { 398 if (UseDirectOffsets) { 399 m_base_mapper.storePacket(i, 0, p); 400 } 401 m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p); 402 } 403 404 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const { 405 if (UseDirectOffsets) { 406 return LinearMapper(m_base_mapper, i, j); 407 } 408 return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); 409 } 410 411 template <typename PacketT, int AlignmentType> 412 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const { 413 EIGEN_STATIC_ASSERT((internal::is_same<PacketT, Packet>::value), YOU_MADE_A_PROGRAMMING_MISTAKE); 414 const int ActualAlignment = (AlignmentType == Aligned) && (Alignment == Aligned) ? Aligned : Unaligned; 415 if (UseDirectOffsets) { 416 return m_base_mapper.template loadPacket<ActualAlignment>(i, 0); 417 } 418 return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset); 419 } 420 421 template <typename Packet> 422 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool aligned(Index) const { 423 return false; 424 } 425 426 private: 427 ParentMapper m_base_mapper; 428 const Index m_vert_offset; 429 const Index m_horiz_offset; 430 }; 431 432 433 template<typename Scalar_, typename Index, int side, 434 typename Tensor, 435 typename nocontract_t, typename contract_t, 436 int packet_size, 437 bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment> 438 class TensorContractionInputMapper 439 : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> { 440 441 public: 442 typedef Scalar_ Scalar; 443 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base; 444 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper; 445 typedef SubMapper VectorMapper; 446 447 EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor, 448 const nocontract_t& nocontract_strides, 449 const nocontract_t& ij_strides, 450 const contract_t& contract_strides, 451 const contract_t& k_strides) 452 : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { } 453 454 EIGEN_DEVICE_FUNC 455 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { 456 return SubMapper(*this, i, j); 457 } 458 459 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { 460 return VectorMapper(*this, i, j); 461 } 462 }; 463 464 465 466 } // end namespace internal 467 } // end namespace Eigen 468 469 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H 470