// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H
#define EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H

namespace Eigen {

/** \class TensorEvaluator
  * \ingroup CXX11_Tensor_Module
  *
  * \brief The tensor evaluator classes.
  *
  * These classes are responsible for the evaluation of the tensor expression.
  *
  * TODO: add support for more types of expressions, in particular expressions
  * leading to lvalues (slicing, reshaping, etc...)
  */
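
// A minimal usage sketch, for orientation only: evaluators are normally
// created and driven by the assignment/executor machinery rather than by
// users directly. The tensor and sizes below are hypothetical.
//
//   Eigen::Tensor<float, 2> t(3, 4);
//   t.setRandom();
//   Eigen::DefaultDevice device;
//   Eigen::TensorEvaluator<Eigen::Tensor<float, 2>, Eigen::DefaultDevice>
//       eval(t, device);
//   const float v = eval.coeff(0);  // coefficient access by linear index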

// Generic evaluator
template<typename Derived, typename Device>
struct TensorEvaluator
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(const_cast<typename internal::traits<Derived>::template MakePointer<Scalar>::Type>(m.data())), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

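  // Contract (as implemented below): if a destination buffer is supplied and
  // the data can be produced by a single bulk copy, the result is written to
  // dest and false is returned, meaning no per-coefficient evaluation is
  // needed; otherwise true is returned and the caller reads the result
  // through coeff()/packet().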
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* dest) {
    if (dest) {
      m_device.memcpy((void*)dest, m_data, sizeof(Scalar) * m_dims.TotalSize());
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return m_data[index];
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    eigen_assert(m_data);
    return m_data[index];
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    return internal::ploadt<PacketReturnType, LoadMode>(m_data + index);
  }

  template <int StoreMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  void writePacket(Index index, const PacketReturnType& x)
  {
    return internal::pstoret<Scalar, PacketReturnType, StoreMode>(m_data + index, x);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<DenseIndex, NumCoords>& coords) {
    eigen_assert(m_data);
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      return m_data[m_dims.IndexOfColMajor(coords)];
    } else {
      return m_data[m_dims.IndexOfRowMajor(coords)];
    }
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<Scalar>::Type data() const { return m_data; }

  /// required by sycl in order to construct a sycl buffer from the raw pointer
  const Device& device() const { return m_device; }

 protected:
  typename internal::traits<Derived>::template MakePointer<Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};

namespace {
template <typename T> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
T loadConstant(const T* address) {
  return *address;
}
// Use the texture cache on CUDA devices whenever possible
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
float loadConstant(const float* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double loadConstant(const double* address) {
  return __ldg(address);
}
template <> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
Eigen::half loadConstant(const Eigen::half* address) {
  return Eigen::half(half_impl::raw_uint16_to_half(__ldg(&address->x)));
}
#endif
}  // namespace


// Default evaluator for rvalues
template<typename Derived, typename Device>
struct TensorEvaluator<const Derived, Device>
{
  typedef typename Derived::Index Index;
  typedef typename Derived::Scalar Scalar;
  typedef typename Derived::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  typedef typename Derived::Dimensions Dimensions;

  // NumDimensions is -1 for variable dim tensors
  static const int NumCoords = internal::traits<Derived>::NumDimensions > 0 ?
                               internal::traits<Derived>::NumDimensions : 0;

  enum {
    IsAligned = Derived::IsAligned,
    PacketAccess = (internal::unpacket_traits<PacketReturnType>::size > 1),
    Layout = Derived::Layout,
    CoordAccess = NumCoords > 0,
    RawAccess = true
  };

  // Used for accessor extraction in SYCL Managed TensorMap:
  const Derived& derived() const { return m_impl; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device)
      : m_data(m.data()), m_dims(m.dimensions()), m_device(device), m_impl(m)
  { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) {
    if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization && data) {
      m_device.memcpy((void*)data, m_data, m_dims.TotalSize() * sizeof(Scalar));
      return false;
    }
    return true;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    eigen_assert(m_data);
    return loadConstant(m_data+index);
  }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  PacketReturnType packet(Index index) const
  {
    return internal::ploadt_ro<PacketReturnType, LoadMode>(m_data + index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(const array<DenseIndex, NumCoords>& coords) const {
    eigen_assert(m_data);
    const Index index = (static_cast<int>(Layout) == static_cast<int>(ColMajor)) ? m_dims.IndexOfColMajor(coords)
                        : m_dims.IndexOfRowMajor(coords);
    return loadConstant(m_data+index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC typename internal::traits<Derived>::template MakePointer<const Scalar>::Type data() const { return m_data; }

  /// required by sycl in order to construct the buffer from the sycl device
  const Device& device() const { return m_device; }

 protected:
  typename internal::traits<Derived>::template MakePointer<const Scalar>::Type m_data;
  Dimensions m_dims;
  const Device& m_device;
  const Derived& m_impl;
};


// -------------------- CwiseNullaryOp --------------------

template<typename NullaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
{
  typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;

  enum {
    IsAligned = true,
    PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC
  TensorEvaluator(const XprType& op, const Device& device)
      : m_functor(op.functor()), m_argImpl(op.nestedExpression(), device), m_wrapper()
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) { return true; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_wrapper(m_functor, index);
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_wrapper.template packetOp<PacketReturnType, Index>(m_functor, index);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized,
                        internal::unpacket_traits<PacketReturnType>::size);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ArgType, Device>& impl() const { return m_argImpl; }
  /// required by sycl in order to extract the accessor
  NullaryOp functor() const { return m_functor; }


 private:
  const NullaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
  const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper;
};
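
// Illustrative sketch only: expressions such as t.constant(v) or t.random()
// produce a TensorCwiseNullaryOp serviced by the evaluator above; the tensor
// and values here are hypothetical.
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   Eigen::Tensor<float, 2> ones = t.constant(1.0f);  // nullary expression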


// -------------------- CwiseUnaryOp --------------------

template<typename UnaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType>, Device>
{
  typedef TensorCwiseUnaryOp<UnaryOp, ArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
    Layout = TensorEvaluator<ArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_functor(op.functor()),
      m_argImpl(op.nestedExpression(), device)
  { }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
    m_argImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_argImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_argImpl.coeff(index));
  }

  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_argImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<UnaryOp>::Cost;
    return m_argImpl.costPerCoeff(vectorized) +
        TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ArgType, Device> & impl() const { return m_argImpl; }
  /// required by sycl in order to extract the accessor
  UnaryOp functor() const { return m_functor; }


 private:
  const UnaryOp m_functor;
  TensorEvaluator<ArgType, Device> m_argImpl;
};
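
// Illustrative sketch only (hypothetical tensor): a coefficient-wise unary
// expression such as t.abs() builds a TensorCwiseUnaryOp handled by the
// evaluator above.
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setRandom();
//   Eigen::Tensor<float, 2> r = t.abs();  // scalar_abs_op applied per coefficient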


// -------------------- CwiseBinaryOp --------------------

template<typename BinaryOp, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType>, Device>
{
  typedef TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArgType> XprType;

  enum {
    IsAligned = TensorEvaluator<LeftArgType, Device>::IsAligned & TensorEvaluator<RightArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<LeftArgType, Device>::PacketAccess & TensorEvaluator<RightArgType, Device>::PacketAccess &
                   internal::functor_traits<BinaryOp>::PacketAccess,
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_functor(op.functor()),
      m_leftImpl(op.lhsExpression(), device),
      m_rightImpl(op.rhsExpression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<LeftArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use right impl instead if right impl dimensions are known at compile time.
    return m_leftImpl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_leftImpl.coeff(index), m_rightImpl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_leftImpl.template packet<LoadMode>(index), m_rightImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<BinaryOp>::Cost;
    return m_leftImpl.costPerCoeff(vectorized) +
           m_rightImpl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<LeftArgType, Device>& left_impl() const { return m_leftImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<RightArgType, Device>& right_impl() const { return m_rightImpl; }
  /// required by sycl in order to extract the accessor
  BinaryOp functor() const { return m_functor; }

 private:
  const BinaryOp m_functor;
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};
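
// Illustrative sketch only (hypothetical tensors): a + b builds a
// TensorCwiseBinaryOp whose left and right sub-evaluators are combined
// coefficient-wise by the evaluator above.
//
//   Eigen::Tensor<float, 2> a(2, 3), b(2, 3);
//   a.setRandom(); b.setRandom();
//   Eigen::Tensor<float, 2> sum = a + b;  // scalar_sum_op per coefficient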

// -------------------- CwiseTernaryOp --------------------

template<typename TernaryOp, typename Arg1Type, typename Arg2Type, typename Arg3Type, typename Device>
struct TensorEvaluator<const TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type>, Device>
{
  typedef TensorCwiseTernaryOp<TernaryOp, Arg1Type, Arg2Type, Arg3Type> XprType;

  enum {
    IsAligned = TensorEvaluator<Arg1Type, Device>::IsAligned & TensorEvaluator<Arg2Type, Device>::IsAligned & TensorEvaluator<Arg3Type, Device>::IsAligned,
    PacketAccess = TensorEvaluator<Arg1Type, Device>::PacketAccess & TensorEvaluator<Arg2Type, Device>::PacketAccess & TensorEvaluator<Arg3Type, Device>::PacketAccess &
                   internal::functor_traits<TernaryOp>::PacketAccess,
    Layout = TensorEvaluator<Arg1Type, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_functor(op.functor()),
      m_arg1Impl(op.arg1Expression(), device),
      m_arg2Impl(op.arg2Expression(), device),
      m_arg3Impl(op.arg3Expression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<Arg1Type, Device>::Layout) == static_cast<int>(TensorEvaluator<Arg3Type, Device>::Layout) || internal::traits<XprType>::NumDimensions <= 1), YOU_MADE_A_PROGRAMMING_MISTAKE);

    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                         typename internal::traits<Arg2Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::StorageKind,
                         typename internal::traits<Arg3Type>::StorageKind>::value),
                        STORAGE_KIND_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                         typename internal::traits<Arg2Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)
    EIGEN_STATIC_ASSERT((internal::is_same<typename internal::traits<Arg1Type>::Index,
                         typename internal::traits<Arg3Type>::Index>::value),
                        STORAGE_INDEX_MUST_MATCH)

    eigen_assert(dimensions_match(m_arg1Impl.dimensions(), m_arg2Impl.dimensions()) && dimensions_match(m_arg1Impl.dimensions(), m_arg3Impl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename XprType::Scalar Scalar;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<Arg1Type, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use arg2 or arg3 dimensions if they are known at compile time.
    return m_arg1Impl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_arg1Impl.evalSubExprsIfNeeded(NULL);
    m_arg2Impl.evalSubExprsIfNeeded(NULL);
    m_arg3Impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_arg1Impl.cleanup();
    m_arg2Impl.cleanup();
    m_arg3Impl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_functor(m_arg1Impl.coeff(index), m_arg2Impl.coeff(index), m_arg3Impl.coeff(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    return m_functor.packetOp(m_arg1Impl.template packet<LoadMode>(index),
                              m_arg2Impl.template packet<LoadMode>(index),
                              m_arg3Impl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    const double functor_cost = internal::functor_traits<TernaryOp>::Cost;
    return m_arg1Impl.costPerCoeff(vectorized) +
           m_arg2Impl.costPerCoeff(vectorized) +
           m_arg3Impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, functor_cost, vectorized, PacketSize);
  }

  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }

  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg1Type, Device> & arg1Impl() const { return m_arg1Impl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg2Type, Device>& arg2Impl() const { return m_arg2Impl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<Arg3Type, Device>& arg3Impl() const { return m_arg3Impl; }

 private:
  const TernaryOp m_functor;
  TensorEvaluator<Arg1Type, Device> m_arg1Impl;
  TensorEvaluator<Arg2Type, Device> m_arg2Impl;
  TensorEvaluator<Arg3Type, Device> m_arg3Impl;
};
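
// Illustrative sketch only: ternary expressions are rare; assuming the
// betainc() tensor function is available in this build, it is one producer
// of TensorCwiseTernaryOp (the tensors below are hypothetical).
//
//   Eigen::Tensor<float, 1> a(8), b(8), x(8);
//   a.setConstant(0.5f); b.setConstant(0.7f); x.setConstant(0.25f);
//   Eigen::Tensor<float, 1> r = Eigen::betainc(a, b, x);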


// -------------------- SelectOp --------------------

template<typename IfArgType, typename ThenArgType, typename ElseArgType, typename Device>
struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>, Device>
{
  typedef TensorSelectOp<IfArgType, ThenArgType, ElseArgType> XprType;
  typedef typename XprType::Scalar Scalar;

  enum {
    IsAligned = TensorEvaluator<ThenArgType, Device>::IsAligned & TensorEvaluator<ElseArgType, Device>::IsAligned,
    PacketAccess = TensorEvaluator<ThenArgType, Device>::PacketAccess & TensorEvaluator<ElseArgType, Device>::PacketAccess &
                   internal::packet_traits<Scalar>::HasBlend,
    Layout = TensorEvaluator<IfArgType, Device>::Layout,
    CoordAccess = false,  // to be implemented
    RawAccess = false
  };

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device)
    : m_condImpl(op.ifExpression(), device),
      m_thenImpl(op.thenExpression(), device),
      m_elseImpl(op.elseExpression(), device)
  {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ThenArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<IfArgType, Device>::Layout) == static_cast<int>(TensorEvaluator<ElseArgType, Device>::Layout)), YOU_MADE_A_PROGRAMMING_MISTAKE);
    eigen_assert(dimensions_match(m_condImpl.dimensions(), m_thenImpl.dimensions()));
    eigen_assert(dimensions_match(m_thenImpl.dimensions(), m_elseImpl.dimensions()));
  }

  typedef typename XprType::Index Index;
  typedef typename internal::traits<XprType>::Scalar CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
  static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
  typedef typename TensorEvaluator<IfArgType, Device>::Dimensions Dimensions;

  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
  {
    // TODO: use then or else impl instead if they happen to be known at compile time.
    return m_condImpl.dimensions();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType*) {
    m_condImpl.evalSubExprsIfNeeded(NULL);
    m_thenImpl.evalSubExprsIfNeeded(NULL);
    m_elseImpl.evalSubExprsIfNeeded(NULL);
    return true;
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_condImpl.cleanup();
    m_thenImpl.cleanup();
    m_elseImpl.cleanup();
  }

  EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
  {
    return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    internal::Selector<PacketSize> select;
    for (Index i = 0; i < PacketSize; ++i) {
      select.select[i] = m_condImpl.coeff(index+i);
    }
    return internal::pblend(select,
                            m_thenImpl.template packet<LoadMode>(index),
                            m_elseImpl.template packet<LoadMode>(index));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    return m_condImpl.costPerCoeff(vectorized) +
           m_thenImpl.costPerCoeff(vectorized)
        .cwiseMax(m_elseImpl.costPerCoeff(vectorized));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType* data() const { return NULL; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<IfArgType, Device> & cond_impl() const { return m_condImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ThenArgType, Device>& then_impl() const { return m_thenImpl; }
  /// required by sycl in order to extract the accessor
  const TensorEvaluator<ElseArgType, Device>& else_impl() const { return m_elseImpl; }

 private:
  TensorEvaluator<IfArgType, Device> m_condImpl;
  TensorEvaluator<ThenArgType, Device> m_thenImpl;
  TensorEvaluator<ElseArgType, Device> m_elseImpl;
};
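
// Illustrative sketch only (hypothetical tensors): a boolean condition
// expression combined with select() builds a TensorSelectOp handled by the
// evaluator above.
//
//   Eigen::Tensor<float, 2> a(2, 3), b(2, 3);
//   a.setRandom(); b.setRandom();
//   Eigen::Tensor<float, 2> m = (a > b).select(a, b);  // coefficient-wise max via select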


} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_EVALUATOR_H