1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Mehdi Goli    Codeplay Software Ltd.
5 // Ralph Potter  Codeplay Software Ltd.
6 // Luke Iwanski  Codeplay Software Ltd.
7 // Contact: <eigen@codeplay.com>
8 //
9 // This Source Code Form is subject to the terms of the Mozilla
10 // Public License v. 2.0. If a copy of the MPL was not distributed
11 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
12 
/*****************************************************************
 * TensorSyclExtractFunctors.h
 *
 * \brief:
 *  Used to extract all the functors allocated to each node of the
 *  expression tree.
 *
*****************************************************************/
21 
22 #ifndef UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
23 #define UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
24 
25 namespace Eigen {
26 namespace TensorSycl {
27 namespace internal {
/// \struct FunctorExtractor
/// Captures, on the host, the functors held by the nodes of a tensor
/// expression so that they can be packed, sent to the device and reused when
/// the expression is rebuilt there. Eigen functors are not stateless, so they
/// cannot simply be re-instantiated on the device; the host-side instances
/// themselves have to be passed along.
///
/// This primary template covers leaf nodes (TensorMap) and nodes that behave
/// like leaves (TensorForcedEval). It keeps nothing but the node's dimensions.
template <typename Evaluator>
struct FunctorExtractor {
  typedef typename Evaluator::Dimensions Dimensions;

  /// Copies the dimensions reported by \p expr.
  FunctorExtractor(const Evaluator& expr) : m_dimensions(expr.dimensions()) {}

  /// Dimensions captured at construction time.
  const Dimensions& dimensions() const { return m_dimensions; }

  const Dimensions m_dimensions;
};
44 
45 /// specialisation of the \ref FunctorExtractor struct when the node type is
46 /// const TensorCwiseNullaryOp, const TensorCwiseUnaryOp, and const TensorBroadcastingOp
47 template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
48 struct FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> > {
49   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
50   OP func;
51   FunctorExtractor(const TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev>& expr)
52   : rhsExpr(expr.impl()), func(expr.functor()) {}
53 };
54 /// specialisation of the \ref FunctorExtractor struct when the node type is
55 /// TensorCwiseNullaryOp, TensorCwiseUnaryOp, and TensorBroadcastingOp
56 template <template <class, class> class UnaryCategory, typename OP, typename RHSExpr, typename Dev>
57 struct FunctorExtractor<TensorEvaluator<UnaryCategory<OP, RHSExpr>, Dev> >
58 : FunctorExtractor<TensorEvaluator<const UnaryCategory<OP, RHSExpr>, Dev> >{};
59 
60 /// specialisation of the \ref FunctorExtractor struct when the node type is
61 /// const TensorCwiseBinaryOp
62 template <template<class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
63 struct FunctorExtractor<TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev> > {
64   FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
65   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
66   OP func;
67   FunctorExtractor(const TensorEvaluator<const BinaryCategory<OP, LHSExpr, RHSExpr>, Dev>& expr)
68   : lhsExpr(expr.left_impl()),rhsExpr(expr.right_impl()),func(expr.functor()) {}
69 };
70 
71 /// specialisation of the \ref FunctorExtractor struct when the node type is
72 /// const TensorCwiseBinaryOp
73 template <template <class, class, class> class BinaryCategory, typename OP, typename LHSExpr, typename RHSExpr, typename Dev>
74 struct FunctorExtractor<TensorEvaluator<BinaryCategory<OP,  LHSExpr, RHSExpr>, Dev> >
75 : FunctorExtractor<TensorEvaluator<const BinaryCategory<OP,  LHSExpr, RHSExpr>, Dev> >{};
76 
77 /// specialisation of the \ref FunctorExtractor struct when the node type is
78 /// const TensorCwiseTernaryOp
79 template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr,typename Dev>
80 struct FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> > {
81   FunctorExtractor<TensorEvaluator<Arg1Expr, Dev> > arg1Expr;
82   FunctorExtractor<TensorEvaluator<Arg2Expr, Dev> > arg2Expr;
83   FunctorExtractor<TensorEvaluator<Arg3Expr, Dev> > arg3Expr;
84   OP func;
85   FunctorExtractor(const TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev>& expr)
86   : arg1Expr(expr.arg1Impl()), arg2Expr(expr.arg2Impl()), arg3Expr(expr.arg3Impl()), func(expr.functor()) {}
87 };
88 
89 /// specialisation of the \ref FunctorExtractor struct when the node type is
90 /// TensorCwiseTernaryOp
91 template <template <class, class, class, class> class TernaryCategory, typename OP, typename Arg1Expr, typename Arg2Expr, typename Arg3Expr, typename Dev>
92 struct FunctorExtractor<TensorEvaluator< TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >
93 :FunctorExtractor<TensorEvaluator<const TernaryCategory<OP, Arg1Expr, Arg2Expr, Arg3Expr>, Dev> >{};
94 
95 /// specialisation of the \ref FunctorExtractor struct when the node type is
96 /// const TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated.
97 template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
98 struct FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {
99   FunctorExtractor<TensorEvaluator<IfExpr, Dev> > ifExpr;
100   FunctorExtractor<TensorEvaluator<ThenExpr, Dev> > thenExpr;
101   FunctorExtractor<TensorEvaluator<ElseExpr, Dev> > elseExpr;
102   FunctorExtractor(const TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev>& expr)
103   : ifExpr(expr.cond_impl()), thenExpr(expr.then_impl()), elseExpr(expr.else_impl()) {}
104 };
105 
106 /// specialisation of the \ref FunctorExtractor struct when the node type is
107 /// TensorCwiseSelectOp. This is an specialisation without OP so it has to be separated
108 template <typename IfExpr, typename ThenExpr, typename ElseExpr, typename Dev>
109 struct FunctorExtractor<TensorEvaluator<TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> >
110 :FunctorExtractor< TensorEvaluator<const TensorSelectOp<IfExpr, ThenExpr, ElseExpr>, Dev> > {};
111 
112 /// specialisation of the \ref FunctorExtractor struct when the node type is
113 /// const TensorAssignOp. This is an specialisation without OP so it has to be separated.
114 template <typename LHSExpr, typename RHSExpr, typename Dev>
115 struct FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> > {
116   FunctorExtractor<TensorEvaluator<LHSExpr, Dev> > lhsExpr;
117   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
118   FunctorExtractor(const TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev>& expr)
119   : lhsExpr(expr.left_impl()), rhsExpr(expr.right_impl()) {}
120 };
121 
122 /// specialisation of the \ref FunctorExtractor struct when the node type is
123 /// TensorAssignOp. This is an specialisation without OP so it has to be separated.
124 template <typename LHSExpr, typename RHSExpr, typename Dev>
125 struct FunctorExtractor<TensorEvaluator<TensorAssignOp<LHSExpr, RHSExpr>, Dev> >
126 :FunctorExtractor<TensorEvaluator<const TensorAssignOp<LHSExpr, RHSExpr>, Dev> >{};
127 
128 
129 /// specialisation of the \ref FunctorExtractor struct when the node type is
130 /// const TensorEvalToOp, This is an specialisation without OP so it has to be separated.
131 template <typename RHSExpr, typename Dev>
132 struct FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {
133   FunctorExtractor<TensorEvaluator<RHSExpr, Dev> > rhsExpr;
134   FunctorExtractor(const TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev>& expr)
135   : rhsExpr(expr.impl()) {}
136 };
137 
138 /// specialisation of the \ref FunctorExtractor struct when the node type is
139 /// TensorEvalToOp. This is a specialisation without OP so it has to be separated.
140 template <typename RHSExpr, typename Dev>
141 struct FunctorExtractor<TensorEvaluator<TensorEvalToOp<RHSExpr>, Dev> >
142 : FunctorExtractor<TensorEvaluator<const TensorEvalToOp<RHSExpr>, Dev> > {};
143 
/// Converts an evaluator's dimensions into the dimension type \p Dim that an
/// extractor stores. In the general case the input is returned as-is
/// (copy-initialised into \p Dim).
template <typename Dim, size_t NumOutputDim>
struct DimConstr {
  template <typename InDim>
  static inline Dim getDim(InDim inDims) {
    return inDims;
  }
};

/// Zero-output-dimension case: \p Dim is constructed from the total size of
/// the input dimensions. The reduction extractor relies on this, pairing it
/// with a one-element DSizes.
template <typename Dim>
struct DimConstr<Dim, 0> {
  template <typename InDim>
  static inline Dim getDim(InDim inDims) {
    return Dim(inDims.TotalSize());
  }
};
153 
154 template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
155 struct FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{
156   typedef TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device> Evaluator;
157   typedef typename Eigen::internal::conditional<Evaluator::NumOutputDims==0, DSizes<typename Evaluator::Index, 1>, typename Evaluator::Dimensions >::type Dimensions;
158   const Dimensions m_dimensions;
159   const Dimensions& dimensions() const { return m_dimensions; }
160   FunctorExtractor(const TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>& expr)
161   : m_dimensions(DimConstr<Dimensions, Evaluator::NumOutputDims>::getDim(expr.dimensions())) {}
162 };
163 
164 
165 template<typename Op, typename Dims, typename ArgType, template <class> class MakePointer_, typename Device>
166 struct FunctorExtractor<TensorEvaluator<TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>
167 : FunctorExtractor<TensorEvaluator<const TensorReductionOp<Op, Dims, ArgType, MakePointer_>, Device>>{};
168 /// template deduction function for FunctorExtractor
169 template <typename Evaluator>
170 auto inline extractFunctors(const Evaluator& evaluator)-> FunctorExtractor<Evaluator> {
171   return FunctorExtractor<Evaluator>(evaluator);
172 }
173 }  // namespace internal
174 }  // namespace TensorSycl
175 }  // namespace Eigen
176 
177 #endif  // UNSUPPORTED_EIGEN_CXX11_SRC_TENSOR_TENSORSYCL_EXTRACT_FUNCTORS_HPP
178