// { dg-do compile }

// Compile-only regression testcase (see the directive above): the file is
// only required to compile, not to link or run.  Many members (e.g.
// TensorAssignOp's constructor, H::H, G::G, L::m_fn10) are declared but
// never defined, and several locals are read while uninitialized.
//
// The code is a heavily reduced model of expression-template machinery;
// identifiers such as TensorContractionOp, TensorChippingOp, TTypes and
// BatchMatMul_context suggest it was reduced from Eigen Tensor /
// TensorFlow sources (provenance inferred from names -- not confirmed).
//
// NOTE(review): do not "fix" the uninitialized variables, the endless loop
// in TensorContractionEvaluatorBase, or the swapped left/right members --
// in a reduced testcase such oddities are presumably exactly what is needed
// to reproduce the original compiler problem.

enum { Aligned, RowMajor };
enum { ReadOnlyAccessors };
// Primary K: 'value' is the first enumerator, i.e. 0.
template <typename> struct K {
  enum { value };
};
template <typename> struct traits;
// traits of a const type forward to traits of the unqualified type.
template <typename T> struct traits<const T> : traits<T> {};
// A::value == 1 (second enumerator); used as N's default second argument.
struct A {
  enum { has_write_access, value };
};
// Minimal fixed-size array: elements are always int (the first template
// parameter is unused), and operator[] returns int.
template <typename, int n> class array {
public:
  int operator[](unsigned long p1) { return values[p1]; }
  int values[n];
};
// I is never defined; it only serves as the default for M's
// template-template parameter.
template <typename> struct I;
template <typename, int, template <class> class = I> class M;
template <typename, int, int, typename> class J;
template <typename, int> class N;
template <typename, typename> class D;
template <typename, typename, typename, typename> class TensorContractionOp;
template <long, typename> class TensorChippingOp;
class C;
// K<array<T, N>>::value == N.
template <typename DenseIndex, int NumDims>
struct K<array<DenseIndex, NumDims>> {
  static const long value = NumDims;
};
// The Index type of a J is its last template argument.
template <typename Scalar_, int NumIndices_, int Options_, typename IndexType_>
struct traits<J<Scalar_, NumIndices_, Options_, IndexType_>> {
  typedef IndexType_ Index;
};
// M forwards its traits to its PlainObjectType.
template <typename PlainObjectType, int Options_,
          template <class> class MakePointer_>
struct traits<M<PlainObjectType, Options_, MakePointer_>>
    : traits<PlainObjectType> {};
template <typename T> struct B { typedef T type; };
// Read-only accessor base (second argument == ReadOnlyAccessors == 0).
// All member functions are declarations only.
template <typename Derived> class N<Derived, ReadOnlyAccessors> {
public:
  typedef typename traits<Derived>::Index Index;
  D<int, Derived> m_fn1();
  template <typename OtherDerived, typename Dimensions>
  TensorContractionOp<Dimensions, Derived, const OtherDerived, int>
      m_fn2(OtherDerived, Dimensions);
  template <Index> TensorChippingOp<1, Derived> m_fn3(Index);
};
// Primary N: the second argument defaults to A::value (== 1), and the class
// derives from the ReadOnlyAccessors specialization above.
template <typename Derived, int = A::value>
class N : public N<Derived, ReadOnlyAccessors> {
public:
  template <typename DeviceType> C m_fn4(DeviceType);
};
template <typename, typename> struct TensorEvaluator;
template <typename UnaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const D<UnaryOp, ArgType>, Device> {
  TensorEvaluator(D<UnaryOp, ArgType>, Device);
};
template <typename, typename> class D {
public:
  typedef typename B<D>::type Nested;
};
// Exposes the template arguments of a contraction evaluator as typedefs;
// consumed by TensorContractionEvaluatorBase<Derived>.
template <typename Indices_, typename LeftArgType_, typename RightArgType_,
          typename OutputKernelType_, typename Device_>
struct traits<
    TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
                                              RightArgType_, OutputKernelType_>,
                    Device_>> {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef OutputKernelType_ OutputKernelType;
  typedef Device_ Device;
};
template <typename, typename LhsXprType, typename RhsXprType, typename>
class TensorContractionOp {
public:
  typedef typename B<TensorContractionOp>::type Nested;
  typename LhsXprType::Nested m_fn5();
  typename RhsXprType::Nested m_fn6();
};
template <typename Derived> struct TensorContractionEvaluatorBase {
  typedef typename traits<Derived>::LeftArgType LeftArgType;
  typedef typename traits<Derived>::RightArgType RightArgType;
  typedef typename traits<Derived>::Device Device;
  // NOTE(review): m_leftImpl is typed on RightArgType and built from m_fn6()
  // (the rhs), and vice versa for m_rightImpl; kept as in the reduced case.
  TensorContractionEvaluatorBase(
      TensorContractionOp<typename traits<Derived>::Indices, LeftArgType,
                          RightArgType,
                          typename traits<Derived>::OutputKernelType>
          p1,
      Device p2)
      : m_leftImpl(p1.m_fn6(), p2), m_rightImpl(p1.m_fn5(), p2) {
    // nocontract_idx, i and contracting are never initialized and the loop
    // has no exit condition -- deliberate, see the header NOTE.
    long nocontract_idx;
    for (int i;; i++) {
      bool contracting;
      if (contracting) {
        // K<int> uses the primary K, so K<int>::value is 0.
        if (nocontract_idx < K<int>::value)
          m_j_size = m_j_strides[nocontract_idx];
        nocontract_idx++;
      }
    }
  }
  array<long, 1> m_j_strides;
  long m_j_size;
  TensorEvaluator<RightArgType, Device> m_leftImpl;
  TensorEvaluator<LeftArgType, Device> m_rightImpl;
};
// Evaluator for a const contraction; all work is in the CRTP-style base,
// which reads its types back through traits<TensorEvaluator<...>>.
template <typename Indices, typename LeftArgType, typename RightArgType,
          typename OutputKernelType, typename Device>
struct TensorEvaluator<
    const TensorContractionOp<Indices, LeftArgType, RightArgType,
                              OutputKernelType>,
    Device>
    : TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType,
                                    OutputKernelType>,
          Device>> {
  typedef TensorEvaluator Self;
  typedef TensorContractionEvaluatorBase<Self> Base;
  TensorEvaluator(
      TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>
          p1,
      Device p2)
      : Base(p1, p2) {}
};
template <long DimId, typename XprType>
struct traits<TensorChippingOp<DimId, XprType>> : traits<XprType> {};
// Note: derives from N<TensorChippingOp<1, XprType>> regardless of the
// first (DimId) template argument.
template <long, typename XprType>
class TensorChippingOp : public N<TensorChippingOp<1, XprType>> {
public:
  typedef typename B<TensorChippingOp>::type Nested;
};
// NumInputDims is K<ArgType::Dimensions>::value, i.e. N when Dimensions is
// array<long, N>.
template <long DimId, typename ArgType, typename Device>
struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
  static const int NumInputDims = K<typename ArgType::Dimensions>::value;
  array<long, NumInputDims> m_dimensions;
};
// Non-const chipping evaluator derives from the const one with DimId fixed
// to 1.
template <long DimId, typename ArgType, typename Device>
struct TensorEvaluator<TensorChippingOp<DimId, ArgType>, Device>
    : TensorEvaluator<const TensorChippingOp<1, ArgType>, Device> {
  TensorEvaluator(TensorChippingOp<DimId, ArgType>, Device);
};
// Assignment node. Its left-hand side type is hard-wired rather than taken
// from the (unnamed) first template parameter.
template <typename, typename RhsXprType> class TensorAssignOp {
public:
  TensorAssignOp(TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,
                 RhsXprType);
  TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>> m_fn7();
  typename RhsXprType::Nested m_fn8();
};
template <typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>,
                       Device> {
  TensorEvaluator(TensorAssignOp<LeftArgType, RightArgType> p1, Device p2)
      : m_leftImpl(p1.m_fn7(), p2), m_rightImpl(p1.m_fn8(), p2) {}
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};
// Instantiates TensorEvaluator<Expression, int> by constructing a temporary;
// 'device' is passed uninitialized (deliberate, see the header NOTE).
template <typename Expression> class F {
public:
  static void m_fn9(Expression p1) {
    int device;
    TensorEvaluator<Expression, int>(p1, device);
  }
};
// operator= wraps m_expression and the contraction in a TensorAssignOp and
// hands it to F<...>::m_fn9, instantiating the evaluator chain above.
class C {
public:
  void
  operator=(TensorContractionOp<array<int, 1>,
                                TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
                                const D<int, M<J<float, 3, 1, int>, 0>>, int>
                p1) {
    TensorAssignOp<
        TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,
        const TensorContractionOp<
            array<int, 1>, TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
            const D<int, M<J<float, 3, 1, int>, 0>>, int>>
        assign(m_expression, p1);
    F<const TensorAssignOp<
        TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,
        const TensorContractionOp<
            array<int, 1>, TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
            const D<int, M<J<float, 3, 1, int>, 0>>, int>>>::m_fn9(assign);
  }
  TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>> m_expression;
};
// J's Dimensions is array<long, NumIndices_>.
template <typename, int NumIndices_, int, typename> class J {
public:
  typedef array<long, NumIndices_> Dimensions;
};
// M derives from N<M<...>>, whose second argument defaults to A::value.
template <typename PlainObjectType, int Options_, template <class> class>
class M : public N<M<PlainObjectType, Options_>> {
public:
  typedef typename PlainObjectType::Dimensions Dimensions;
};
// TTypes<NDIMS>::ConstTensor is M<J<float, NDIMS, RowMajor, int>, Aligned>.
template <int NDIMS> struct TTypes {
  typedef M<J<float, NDIMS, RowMajor, int>, Aligned> ConstTensor;
};
class L {
public:
  template <typename, long NDIMS> typename TTypes<NDIMS>::ConstTensor m_fn10();
};
class H {
public:
  H(int *);
};
// G's constructor takes a function parameter (adjusted to H *(*)(int *)),
// to which the captureless lambda at the end of the file converts.
class G {
public:
  G(H *(int *));
};
int Run_d;
// O's constructor builds a TensorChippingOp (m_fn3) and a
// TensorContractionOp (m_fn2) and assigns the latter through the C returned
// by m_fn4, which triggers all template instantiations above.
class O : H {
public:
  int BatchMatMul_context;
  O() : H(&BatchMatMul_context) {
    L out, in_y, in_x;
    auto Tx = in_x.m_fn10<float, 3>(), Ty = in_y.m_fn10<float, 3>(),
         Tz = out.m_fn10<float, 3>(), z = Tz;
    array<int, 1> contract_pairs;
    auto x = Tx.m_fn3<0>(0);
    auto y = Ty.m_fn1();
    z.m_fn4(Run_d) = x.m_fn2(y, contract_pairs);
  }
};
// Namespace-scope object whose initializer constructs G from a lambda that
// creates an O.
G registrar__body__0__object([](int *) -> H * { O(); return 0; });