1 //===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Provides helper classes and syntactic sugar for declarative builders.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_EDSC_HELPERS_H_
14 #define MLIR_EDSC_HELPERS_H_
15 
16 #include "mlir/EDSC/Builders.h"
17 #include "mlir/EDSC/Intrinsics.h"
18 
19 namespace mlir {
20 namespace edsc {
21 
22 // A TemplatedIndexedValue brings an index notation over the template Load and
23 // Store parameters.
24 template <typename Load, typename Store> class TemplatedIndexedValue;
25 
26 // By default, edsc::IndexedValue provides an index notation around the affine
27 // load and stores. edsc::StdIndexedValue provides the standard load/store
28 // counterpart.
29 using IndexedValue =
30     TemplatedIndexedValue<intrinsics::affine_load, intrinsics::affine_store>;
31 using StdIndexedValue =
32     TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
33 
34 // Base class for MemRefView and VectorView.
35 class View {
36 public:
rank()37   unsigned rank() const { return lbs.size(); }
lb(unsigned idx)38   ValueHandle lb(unsigned idx) { return lbs[idx]; }
ub(unsigned idx)39   ValueHandle ub(unsigned idx) { return ubs[idx]; }
step(unsigned idx)40   int64_t step(unsigned idx) { return steps[idx]; }
range(unsigned idx)41   std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) {
42     return std::make_tuple(lbs[idx], ubs[idx], steps[idx]);
43   }
swapRanges(unsigned i,unsigned j)44   void swapRanges(unsigned i, unsigned j) {
45     if (i == j)
46       return;
47     lbs[i].swap(lbs[j]);
48     ubs[i].swap(ubs[j]);
49     std::swap(steps[i], steps[j]);
50   }
51 
getLbs()52   ArrayRef<ValueHandle> getLbs() { return lbs; }
getUbs()53   ArrayRef<ValueHandle> getUbs() { return ubs; }
getSteps()54   ArrayRef<int64_t> getSteps() { return steps; }
55 
56 protected:
57   SmallVector<ValueHandle, 8> lbs;
58   SmallVector<ValueHandle, 8> ubs;
59   SmallVector<int64_t, 8> steps;
60 };
61 
62 /// A MemRefView represents the information required to step through a
63 /// MemRef. It has placeholders for non-contiguous tensors that fit within the
64 /// Fortran subarray model.
65 /// At the moment it can only capture a MemRef with an identity layout map.
66 // TODO(ntv): Support MemRefs with layoutMaps.
67 class MemRefView : public View {
68 public:
69   explicit MemRefView(Value v);
70   MemRefView(const MemRefView &) = default;
71   MemRefView &operator=(const MemRefView &) = default;
72 
fastestVarying()73   unsigned fastestVarying() const { return rank() - 1; }
74 
75 private:
76   friend IndexedValue;
77   ValueHandle base;
78 };
79 
80 /// A VectorView represents the information required to step through a
81 /// Vector accessing each scalar element at a time. It is the counterpart of
82 /// a MemRefView but for vectors. This exists purely for boilerplate avoidance.
83 class VectorView : public View {
84 public:
85   explicit VectorView(Value v);
86   VectorView(const VectorView &) = default;
87   VectorView &operator=(const VectorView &) = default;
88 
89 private:
90   friend IndexedValue;
91   ValueHandle base;
92 };
93 
94 /// A TemplatedIndexedValue brings an index notation over the template Load and
95 /// Store parameters. This helper class is an abstraction purely for sugaring
96 /// purposes and allows writing compact expressions such as:
97 ///
98 /// ```mlir
99 ///    // `IndexedValue` provided by default in the mlir::edsc namespace.
100 ///    using IndexedValue =
101 ///      TemplatedIndexedValue<intrinsics::load, intrinsics::store>;
102 ///    IndexedValue A(...), B(...), C(...);
103 ///    For(ivs, zeros, shapeA, ones, {
104 ///      C(ivs) = A(ivs) + B(ivs)
105 ///    });
106 /// ```
107 ///
108 /// Assigning to an IndexedValue emits an actual `Store` operation, while
109 /// converting an IndexedValue to a ValueHandle emits an actual `Load`
110 /// operation.
111 template <typename Load, typename Store> class TemplatedIndexedValue {
112 public:
TemplatedIndexedValue(Type t)113   explicit TemplatedIndexedValue(Type t) : base(t) {}
TemplatedIndexedValue(Value v)114   explicit TemplatedIndexedValue(Value v)
115       : TemplatedIndexedValue(ValueHandle(v)) {}
TemplatedIndexedValue(ValueHandle v)116   explicit TemplatedIndexedValue(ValueHandle v) : base(v) {}
117 
118   TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default;
119 
operator()120   TemplatedIndexedValue operator()() { return *this; }
121   /// Returns a new `TemplatedIndexedValue`.
operator()122   TemplatedIndexedValue operator()(ValueHandle index) {
123     TemplatedIndexedValue res(base);
124     res.indices.push_back(index);
125     return res;
126   }
127   template <typename... Args>
operator()128   TemplatedIndexedValue operator()(ValueHandle index, Args... indices) {
129     return TemplatedIndexedValue(base, index).append(indices...);
130   }
operator()131   TemplatedIndexedValue operator()(ArrayRef<ValueHandle> indices) {
132     return TemplatedIndexedValue(base, indices);
133   }
operator()134   TemplatedIndexedValue operator()(ArrayRef<IndexHandle> indices) {
135     return TemplatedIndexedValue(
136         base, ArrayRef<ValueHandle>(indices.begin(), indices.end()));
137   }
138 
139   /// Emits a `store`.
140   // NOLINTNEXTLINE: unconventional-assign-operator
141   OperationHandle operator=(const TemplatedIndexedValue &rhs) {
142     ValueHandle rrhs(rhs);
143     return Store(rrhs, getBase(), {indices.begin(), indices.end()});
144   }
145   // NOLINTNEXTLINE: unconventional-assign-operator
146   OperationHandle operator=(ValueHandle rhs) {
147     return Store(rhs, getBase(), {indices.begin(), indices.end()});
148   }
149 
150   /// Emits a `load` when converting to a ValueHandle.
ValueHandle()151   operator ValueHandle() const {
152     return Load(getBase(), {indices.begin(), indices.end()});
153   }
154 
155   /// Emits a `load` when converting to a Value.
156   Value operator*(void) const {
157     return Load(getBase(), {indices.begin(), indices.end()}).getValue();
158   }
159 
getBase()160   ValueHandle getBase() const { return base; }
161 
162   /// Operator overloadings.
163   ValueHandle operator+(ValueHandle e);
164   ValueHandle operator-(ValueHandle e);
165   ValueHandle operator*(ValueHandle e);
166   ValueHandle operator/(ValueHandle e);
167   OperationHandle operator+=(ValueHandle e);
168   OperationHandle operator-=(ValueHandle e);
169   OperationHandle operator*=(ValueHandle e);
170   OperationHandle operator/=(ValueHandle e);
171   ValueHandle operator+(TemplatedIndexedValue e) {
172     return *this + static_cast<ValueHandle>(e);
173   }
174   ValueHandle operator-(TemplatedIndexedValue e) {
175     return *this - static_cast<ValueHandle>(e);
176   }
177   ValueHandle operator*(TemplatedIndexedValue e) {
178     return *this * static_cast<ValueHandle>(e);
179   }
180   ValueHandle operator/(TemplatedIndexedValue e) {
181     return *this / static_cast<ValueHandle>(e);
182   }
183   OperationHandle operator+=(TemplatedIndexedValue e) {
184     return this->operator+=(static_cast<ValueHandle>(e));
185   }
186   OperationHandle operator-=(TemplatedIndexedValue e) {
187     return this->operator-=(static_cast<ValueHandle>(e));
188   }
189   OperationHandle operator*=(TemplatedIndexedValue e) {
190     return this->operator*=(static_cast<ValueHandle>(e));
191   }
192   OperationHandle operator/=(TemplatedIndexedValue e) {
193     return this->operator/=(static_cast<ValueHandle>(e));
194   }
195 
196 private:
TemplatedIndexedValue(ValueHandle base,ArrayRef<ValueHandle> indices)197   TemplatedIndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices)
198       : base(base), indices(indices.begin(), indices.end()) {}
199 
append()200   TemplatedIndexedValue &append() { return *this; }
201 
202   template <typename T, typename... Args>
append(T index,Args...indices)203   TemplatedIndexedValue &append(T index, Args... indices) {
204     this->indices.push_back(static_cast<ValueHandle>(index));
205     append(indices...);
206     return *this;
207   }
208   ValueHandle base;
209   SmallVector<ValueHandle, 8> indices;
210 };
211 
212 /// Operator overloadings.
213 template <typename Load, typename Store>
214 ValueHandle TemplatedIndexedValue<Load, Store>::operator+(ValueHandle e) {
215   using op::operator+;
216   return static_cast<ValueHandle>(*this) + e;
217 }
218 template <typename Load, typename Store>
219 ValueHandle TemplatedIndexedValue<Load, Store>::operator-(ValueHandle e) {
220   using op::operator-;
221   return static_cast<ValueHandle>(*this) - e;
222 }
223 template <typename Load, typename Store>
224 ValueHandle TemplatedIndexedValue<Load, Store>::operator*(ValueHandle e) {
225   using op::operator*;
226   return static_cast<ValueHandle>(*this) * e;
227 }
228 template <typename Load, typename Store>
229 ValueHandle TemplatedIndexedValue<Load, Store>::operator/(ValueHandle e) {
230   using op::operator/;
231   return static_cast<ValueHandle>(*this) / e;
232 }
233 
234 template <typename Load, typename Store>
235 OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(ValueHandle e) {
236   using op::operator+;
237   return Store(*this + e, getBase(), {indices.begin(), indices.end()});
238 }
239 template <typename Load, typename Store>
240 OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(ValueHandle e) {
241   using op::operator-;
242   return Store(*this - e, getBase(), {indices.begin(), indices.end()});
243 }
244 template <typename Load, typename Store>
245 OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(ValueHandle e) {
246   using op::operator*;
247   return Store(*this * e, getBase(), {indices.begin(), indices.end()});
248 }
249 template <typename Load, typename Store>
250 OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(ValueHandle e) {
251   using op::operator/;
252   return Store(*this / e, getBase(), {indices.begin(), indices.end()});
253 }
254 
255 } // namespace edsc
256 } // namespace mlir
257 
258 #endif // MLIR_EDSC_HELPERS_H_
259