1 // { dg-do compile }
2 
// Unscoped enumerators used as non-type template arguments further down:
// RowMajor (== 1) is passed as J's Options_ argument and Aligned (== 0) as
// M's Options_ argument in TTypes; ReadOnlyAccessors (== 0) selects the
// read-only partial specialization of N.
enum { Aligned, RowMajor };
enum { ReadOnlyAccessors };
// Rank trait. The primary template provides a single unnamed-enum
// enumerator `value` (== 0); this is what K<int>::value resolves to inside
// TensorContractionEvaluatorBase. A partial specialization for array<>
// defined later in the file yields the array's element count instead.
template <typename> struct K {
  enum { value };
};
// Type-traits hook. The primary template is declared but never defined, so
// only types matched by one of the specializations in this file have traits.
// A const-qualified type inherits every member of its unqualified type's
// traits.
template <typename> struct traits;
template <typename T> struct traits<const T> : traits<T> {};
// A::value (== 1) is the default for N's second template parameter. A plain
// N<Derived> therefore names the primary N template, not the
// N<Derived, ReadOnlyAccessors> (== 0) partial specialization.
// has_write_access (== 0) is not referenced anywhere in this file.
struct A {
  enum { has_write_access, value };
};
// Minimal fixed-size array of n ints. The first template parameter is
// unnamed and unused, so the element type is always int whatever is passed.
// operator[] is non-const, returns by value, and does no bounds checking.
template <typename, int n> class array {
public:
  int operator[](unsigned long p1) { return values[p1]; }
  int values[n];
};
// Forward declarations for the templates and class defined further down.
// I is only used as the default template-template argument of M and is never
// defined. M and J are named (e.g. by class C) well before their own
// definitions near the end of the file.
// NOTE(review): this file is a compile-only test; keep the declaration order
// as-is, since it may be part of what the test exercises.
template <typename> struct I;
template <typename, int, template <class> class = I> class M;
template <typename, int, int, typename> class J;
template <typename, int> class N;
template <typename, typename> class D;
template <typename, typename, typename, typename> class TensorContractionOp;
template <long, typename> class TensorChippingOp;
class C;
// Rank of an array<DenseIndex, NumDims>, which is its element count. Unlike
// the primary K template (an enumerator), this is a `static const long`
// member. The const chipping evaluator reads it through
// K<typename ArgType::Dimensions>::value.
template <typename DenseIndex, int NumDims>
struct K<array<DenseIndex, NumDims>> {
  static const long value = NumDims;
};
// Traits of the tensor type J: only the index type is exposed. N reads it as
// traits<Derived>::Index, reaching this specialization through the
// forwarding traits of const / M / TensorChippingOp.
template <typename Scalar_, int NumIndices_, int Options_, typename IndexType_>
struct traits<J<Scalar_, NumIndices_, Options_, IndexType_>> {
  typedef IndexType_ Index;
};
// Traits of the tensor map M are inherited wholesale from the wrapped
// PlainObjectType, whatever Options_ and MakePointer_ are.
template <typename PlainObjectType, int Options_,
          template <class> class MakePointer_>
struct traits<M<PlainObjectType, Options_, MakePointer_>>
    : traits<PlainObjectType> {};
// Identity metafunction: B<T>::type is T. The expression classes below use it
// to define their `Nested` typedef as the expression type itself.
template <typename T> struct B { typedef T type; };
// Read-only expression interface. It is the base of N's primary template,
// and through that of M and TensorChippingOp. The member functions are only
// declared, never defined. O's constructor calls all three, and only their
// return types matter because the file is only compiled (dg-do compile).
//   m_fn1: wraps Derived in D<int, Derived>.
//   m_fn2: builds a contraction of Derived with a const OtherDerived.
//   m_fn3: member template whose non-type template parameter has the
//          dependent type traits<Derived>::Index; yields
//          TensorChippingOp<1, Derived>.
template <typename Derived> class N<Derived, ReadOnlyAccessors> {
public:
  typedef typename traits<Derived>::Index Index;
  D<int, Derived> m_fn1();
  template <typename OtherDerived, typename Dimensions>
  TensorContractionOp<Dimensions, Derived, const OtherDerived, int>
      m_fn2(OtherDerived, Dimensions);
  template <Index> TensorChippingOp<1, Derived> m_fn3(Index);
};
// Primary N template. The default second argument A::value (== 1) differs
// from ReadOnlyAccessors (== 0), so N<Derived> lands here and inherits the
// read-only interface. m_fn4 is declared only and returns a C; O's
// constructor assigns a contraction expression to that C.
template <typename Derived, int = A::value>
class N : public N<Derived, ReadOnlyAccessors> {
public:
  template <typename DeviceType> C m_fn4(DeviceType);
};
// Evaluator of an expression type for a given Device. The primary template is
// only declared; each supported expression gets its own partial
// specialization. This one, for a const D<UnaryOp, ArgType>, only declares a
// constructor. It is used, for example, for the `const D<...>` right operand
// of the contraction built in O's constructor.
template <typename, typename> struct TensorEvaluator;
template <typename UnaryOp, typename ArgType, typename Device>
struct TensorEvaluator<const D<UnaryOp, ArgType>, Device> {
  TensorEvaluator(D<UnaryOp, ArgType>, Device);
};
// Unary wrapper expression, returned by N::m_fn1. Both template parameters
// are unnamed and unused; Nested names D itself.
template <typename, typename> class D {
public:
  typedef typename B<D>::type Nested;
};
// Traits of the contraction evaluator. Each template argument of
// TensorEvaluator<const TensorContractionOp<...>, Device> is re-exported
// under a name, so TensorContractionEvaluatorBase<Derived> can recover them
// from Derived alone.
template <typename Indices_, typename LeftArgType_, typename RightArgType_,
          typename OutputKernelType_, typename Device_>
struct traits<
    TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
                                              RightArgType_, OutputKernelType_>,
                    Device_>> {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef OutputKernelType_ OutputKernelType;
  typedef Device_ Device;
};
// Contraction expression. Only LhsXprType and RhsXprType are named. The
// indices type (first argument) and the output-kernel type (fourth argument)
// are carried in the type but not used inside the class. m_fn5 and m_fn6 are
// declared only; they return the Nested types of the left and right operands.
template <typename, typename LhsXprType, typename RhsXprType, typename>
class TensorContractionOp {
public:
  typedef typename B<TensorContractionOp>::type Nested;
  typename LhsXprType::Nested m_fn5();
  typename RhsXprType::Nested m_fn6();
};
// Shared base of the contraction evaluator. Derived is the concrete
// evaluator; its traits<> specialization supplies the operand, index,
// output-kernel and device types.
template <typename Derived> struct TensorContractionEvaluatorBase {
  typedef typename traits<Derived>::LeftArgType LeftArgType;
  typedef typename traits<Derived>::RightArgType RightArgType;
  typedef typename traits<Derived>::Device Device;
  // The members are cross-wired. m_leftImpl evaluates the RIGHT operand and
  // is built from m_fn6() (RhsXprType::Nested). m_rightImpl evaluates the
  // LEFT operand and is built from m_fn5(). The member types declared below
  // follow the same wiring, so the initializers type-check.
  TensorContractionEvaluatorBase(
      TensorContractionOp<typename traits<Derived>::Indices, LeftArgType,
                          RightArgType,
                          typename traits<Derived>::OutputKernelType>
          p1,
      Device p2)
      : m_leftImpl(p1.m_fn6(), p2), m_rightImpl(p1.m_fn5(), p2) {
    // NOTE(review): nocontract_idx, i and contracting are read uninitialized,
    // and the loop has no exit condition. If it ever ran, this body would be
    // undefined / non-terminating. The file is only compiled (dg-do compile),
    // and these look like leftovers of test-case reduction. Presumably only
    // the statement shapes matter to the test, so do not "fix" them.
    long nocontract_idx;
    for (int i;; i++) {
      bool contracting;
      if (contracting) {
        // K<int> has no specialization, so this is the primary template's
        // enumerator (0).
        if (nocontract_idx < K<int>::value)
          m_j_size = m_j_strides[nocontract_idx];
        nocontract_idx++;
      }
    }
  }
  array<long, 1> m_j_strides;
  long m_j_size;
  TensorEvaluator<RightArgType, Device> m_leftImpl;
  TensorEvaluator<LeftArgType, Device> m_rightImpl;
};
// Evaluator of a const contraction expression. It derives from
// TensorContractionEvaluatorBase instantiated with this very class. The base
// gets the operand, index, kernel and device types back through the
// traits<TensorEvaluator<const TensorContractionOp<...>, Device>>
// specialization. The constructor just forwards to the base.
template <typename Indices, typename LeftArgType, typename RightArgType,
          typename OutputKernelType, typename Device>
struct TensorEvaluator<
    const TensorContractionOp<Indices, LeftArgType, RightArgType,
                              OutputKernelType>,
    Device>
    : TensorContractionEvaluatorBase<TensorEvaluator<
          const TensorContractionOp<Indices, LeftArgType, RightArgType,
                                    OutputKernelType>,
          Device>> {
  // Injected-class-name: Self is this specialization.
  typedef TensorEvaluator Self;
  typedef TensorContractionEvaluatorBase<Self> Base;
  TensorEvaluator(
      TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>
          p1,
      Device p2)
      : Base(p1, p2) {}
};
// A chipping expression has the traits of the expression it chips.
template <long DimId, typename XprType>
struct traits<TensorChippingOp<DimId, XprType>> : traits<XprType> {};
// Chipping expression, returned by N::m_fn3. The first template parameter is
// unnamed. The N base is always instantiated with
// TensorChippingOp<1, XprType>, whatever the actual first argument is, so a
// TensorChippingOp<0, X> gets the interface of TensorChippingOp<1, X>.
template <long, typename XprType>
class TensorChippingOp : public N<TensorChippingOp<1, XprType>> {
public:
  typedef typename B<TensorChippingOp>::type Nested;
};
// Evaluator of a const chipping expression. NumInputDims is the rank of the
// chipped expression, computed as K<> of ArgType::Dimensions. For the
// J-backed M types in this file, Dimensions is array<long, NumIndices_>, so
// the `static const long` K<array<...>>::value is converted to int here.
// No constructor is declared, so only the implicit ones exist.
template <long DimId, typename ArgType, typename Device>
struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
  static const int NumInputDims = K<typename ArgType::Dimensions>::value;
  array<long, NumInputDims> m_dimensions;
};
// Evaluator of a non-const chipping expression. It reuses the const
// evaluator, with the dimension argument fixed to 1 rather than DimId. Its
// constructor is only declared.
template <long DimId, typename ArgType, typename Device>
struct TensorEvaluator<TensorChippingOp<DimId, ArgType>, Device>
    : TensorEvaluator<const TensorChippingOp<1, ArgType>, Device> {
  TensorEvaluator(TensorChippingOp<DimId, ArgType>, Device);
};
// Assignment expression. The first template parameter is unnamed and unused.
// The left-hand type is hard-coded to
// TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>> in both the constructor
// and m_fn7. All members are declarations only.
template <typename, typename RhsXprType> class TensorAssignOp {
public:
  TensorAssignOp(TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,
                 RhsXprType);
  TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>> m_fn7();
  typename RhsXprType::Nested m_fn8();
};
// Evaluator of a const assignment expression. It builds one evaluator for the
// left operand (from m_fn7()) and one for the right operand (from m_fn8()),
// both for the same Device. Unlike in the contraction evaluator base, each
// member here evaluates the side its name suggests.
template <typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorAssignOp<LeftArgType, RightArgType>,
                       Device> {
  TensorEvaluator(TensorAssignOp<LeftArgType, RightArgType> p1, Device p2)
      : m_leftImpl(p1.m_fn7(), p2), m_rightImpl(p1.m_fn8(), p2) {}
  TensorEvaluator<LeftArgType, Device> m_leftImpl;
  TensorEvaluator<RightArgType, Device> m_rightImpl;
};
// Expression driver. m_fn9 builds a TensorEvaluator<Expression, int> from the
// expression and an int "device".
template <typename Expression> class F {
public:
  static void m_fn9(Expression p1) {
    // NOTE(review): `device` is read uninitialized. That only matters at run
    // time, which this compile-only test never reaches; it is presumably a
    // reduction artifact.
    int device;
    // Expression statement: constructs a temporary evaluator and discards it.
    TensorEvaluator<Expression, int>(p1, device);
  }
};
// Assignment target returned by N::m_fn4. Its operator= takes the contraction
// expression built in O's constructor and wraps it, together with
// m_expression, in a TensorAssignOp. It then evaluates that through
// F<...>::m_fn9, which reaches the evaluator templates defined above.
// operator= returns void; only its use as an assignment target matters.
// NOTE(review): M and J are defined only after this class. The instantiations
// reached from here presumably rely on the compiler instantiating the
// function-template bodies (e.g. F::m_fn9) at a later point of instantiation.
// Keep this declaration order as-is.
class C {
public:
  void
  operator=(TensorContractionOp<array<int, 1>,
                                TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
                                const D<int, M<J<float, 3, 1, int>, 0>>, int>
                p1) {
    TensorAssignOp<
        TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,
        const TensorContractionOp<
            array<int, 1>, TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
            const D<int, M<J<float, 3, 1, int>, 0>>, int>>
        assign(m_expression, p1);
    F<const TensorAssignOp<
        TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>>,
        const TensorContractionOp<
            array<int, 1>, TensorChippingOp<1, M<J<float, 3, 1, int>, 0>>,
            const D<int, M<J<float, 3, 1, int>, 0>>, int>>>::m_fn9(assign);
  }
  TensorChippingOp<0, const M<J<int, 3, 1, int>, 1>> m_expression;
};
// Dense tensor type. It only exposes Dimensions, as array<long, NumIndices_>.
// The scalar, options and index-type parameters are unnamed here; the index
// type is consumed through traits<J<...>>.
template <typename, int NumIndices_, int, typename> class J {
public:
  typedef array<long, NumIndices_> Dimensions;
};
// Tensor map over PlainObjectType. The N base is instantiated with
// M<PlainObjectType, Options_>, i.e. with the default template-template
// argument I, whatever the third argument was. Dimensions are forwarded from
// the wrapped type.
template <typename PlainObjectType, int Options_, template <class> class>
class M : public N<M<PlainObjectType, Options_>> {
public:
  typedef typename PlainObjectType::Dimensions Dimensions;
};
// Type bundle. ConstTensor is M over J<float, NDIMS, RowMajor, int>, with
// Aligned as M's Options_ argument.
template <int NDIMS> struct TTypes {
  typedef M<J<float, NDIMS, RowMajor, int>, Aligned> ConstTensor;
};
// Tensor source. m_fn10<T, NDIMS>() is declared only and returns
// TTypes<NDIMS>::ConstTensor. The first template parameter is unused, and the
// long NDIMS is converted to TTypes' int parameter.
class L {
public:
  template <typename, long NDIMS> typename TTypes<NDIMS>::ConstTensor m_fn10();
};
// Base class of O, constructible from an int pointer (constructor declared
// only).
class H {
public:
  H(int *);
};
// Registration helper. The constructor parameter is written with function
// type `H *(int *)`. As a function parameter, that type is adjusted to the
// pointer type `H *(*)(int *)`, which a captureless lambda converts to.
class G {
public:
  G(H *(int *));
};
// Global int passed as the DeviceType argument of N::m_fn4 in O's
// constructor.
int Run_d;
// Builds the expression whose assignment pulls in the templates above.
// O inherits privately from H (the default access for `class`) and passes
// the address of its own member to H's constructor.
class O : H {
public:
  int BatchMatMul_context;
  O() : H(&BatchMatMul_context) {
    L out, in_y, in_x;
    // Three M<J<float, 3, RowMajor, int>, Aligned> values; z is a copy of Tz.
    auto Tx = in_x.m_fn10<float, 3>(), Ty = in_y.m_fn10<float, 3>(),
         Tz = out.m_fn10<float, 3>(), z = Tz;
    // NOTE(review): contract_pairs is never initialized before being passed
    // by value. Only its type (array<int, 1>, which becomes the contraction's
    // index type) matters in this compile-only test.
    array<int, 1> contract_pairs;
    // x: TensorChippingOp<1, M<...>>; y: D<int, M<...>>.
    auto x = Tx.m_fn3<0>(0);
    auto y = Ty.m_fn1();
    // Contract x with y and assign the result to the C returned by z.m_fn4,
    // which invokes C::operator=.
    z.m_fn4(Run_d) = x.m_fn2(y, contract_pairs);
  }
};
// Global registration object. The captureless lambda converts to the
// function pointer G's constructor takes. Its body constructs and discards a
// temporary O, then returns a null H*.
G registrar__body__0__object([](int *) -> H * { O(); return 0; });
224