1 //===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Provides helper classes and syntactic sugar for declarative builders. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_EDSC_HELPERS_H_ 14 #define MLIR_EDSC_HELPERS_H_ 15 16 #include "mlir/EDSC/Builders.h" 17 #include "mlir/EDSC/Intrinsics.h" 18 19 namespace mlir { 20 namespace edsc { 21 22 // A TemplatedIndexedValue brings an index notation over the template Load and 23 // Store parameters. 24 template <typename Load, typename Store> class TemplatedIndexedValue; 25 26 // By default, edsc::IndexedValue provides an index notation around the affine 27 // load and stores. edsc::StdIndexedValue provides the standard load/store 28 // counterpart. 29 using IndexedValue = 30 TemplatedIndexedValue<intrinsics::affine_load, intrinsics::affine_store>; 31 using StdIndexedValue = 32 TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>; 33 34 // Base class for MemRefView and VectorView. 35 class View { 36 public: rank()37 unsigned rank() const { return lbs.size(); } lb(unsigned idx)38 ValueHandle lb(unsigned idx) { return lbs[idx]; } ub(unsigned idx)39 ValueHandle ub(unsigned idx) { return ubs[idx]; } step(unsigned idx)40 int64_t step(unsigned idx) { return steps[idx]; } range(unsigned idx)41 std::tuple<ValueHandle, ValueHandle, int64_t> range(unsigned idx) { 42 return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); 43 } swapRanges(unsigned i,unsigned j)44 void swapRanges(unsigned i, unsigned j) { 45 if (i == j) 46 return; 47 lbs[i].swap(lbs[j]); 48 ubs[i].swap(ubs[j]); 49 std::swap(steps[i], steps[j]); 50 } 51 getLbs()52 ArrayRef<ValueHandle> getLbs() { return lbs; } getUbs()53 ArrayRef<ValueHandle> getUbs() { return ubs; } getSteps()54 ArrayRef<int64_t> getSteps() { return steps; } 55 56 protected: 57 SmallVector<ValueHandle, 8> lbs; 58 SmallVector<ValueHandle, 8> ubs; 59 SmallVector<int64_t, 8> steps; 60 }; 61 62 /// A MemRefView represents the information required to step through a 63 /// MemRef. It has placeholders for non-contiguous tensors that fit within the 64 /// Fortran subarray model. 65 /// At the moment it can only capture a MemRef with an identity layout map. 66 // TODO(ntv): Support MemRefs with layoutMaps. 67 class MemRefView : public View { 68 public: 69 explicit MemRefView(Value v); 70 MemRefView(const MemRefView &) = default; 71 MemRefView &operator=(const MemRefView &) = default; 72 fastestVarying()73 unsigned fastestVarying() const { return rank() - 1; } 74 75 private: 76 friend IndexedValue; 77 ValueHandle base; 78 }; 79 80 /// A VectorView represents the information required to step through a 81 /// Vector accessing each scalar element at a time. It is the counterpart of 82 /// a MemRefView but for vectors. This exists purely for boilerplate avoidance. 83 class VectorView : public View { 84 public: 85 explicit VectorView(Value v); 86 VectorView(const VectorView &) = default; 87 VectorView &operator=(const VectorView &) = default; 88 89 private: 90 friend IndexedValue; 91 ValueHandle base; 92 }; 93 94 /// A TemplatedIndexedValue brings an index notation over the template Load and 95 /// Store parameters. This helper class is an abstraction purely for sugaring 96 /// purposes and allows writing compact expressions such as: 97 /// 98 /// ```mlir 99 /// // `IndexedValue` provided by default in the mlir::edsc namespace. 100 /// using IndexedValue = 101 /// TemplatedIndexedValue<intrinsics::load, intrinsics::store>; 102 /// IndexedValue A(...), B(...), C(...); 103 /// For(ivs, zeros, shapeA, ones, { 104 /// C(ivs) = A(ivs) + B(ivs) 105 /// }); 106 /// ``` 107 /// 108 /// Assigning to an IndexedValue emits an actual `Store` operation, while 109 /// converting an IndexedValue to a ValueHandle emits an actual `Load` 110 /// operation. 111 template <typename Load, typename Store> class TemplatedIndexedValue { 112 public: TemplatedIndexedValue(Type t)113 explicit TemplatedIndexedValue(Type t) : base(t) {} TemplatedIndexedValue(Value v)114 explicit TemplatedIndexedValue(Value v) 115 : TemplatedIndexedValue(ValueHandle(v)) {} TemplatedIndexedValue(ValueHandle v)116 explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} 117 118 TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default; 119 operator()120 TemplatedIndexedValue operator()() { return *this; } 121 /// Returns a new `TemplatedIndexedValue`. operator()122 TemplatedIndexedValue operator()(ValueHandle index) { 123 TemplatedIndexedValue res(base); 124 res.indices.push_back(index); 125 return res; 126 } 127 template <typename... Args> operator()128 TemplatedIndexedValue operator()(ValueHandle index, Args... indices) { 129 return TemplatedIndexedValue(base, index).append(indices...); 130 } operator()131 TemplatedIndexedValue operator()(ArrayRef<ValueHandle> indices) { 132 return TemplatedIndexedValue(base, indices); 133 } operator()134 TemplatedIndexedValue operator()(ArrayRef<IndexHandle> indices) { 135 return TemplatedIndexedValue( 136 base, ArrayRef<ValueHandle>(indices.begin(), indices.end())); 137 } 138 139 /// Emits a `store`. 140 // NOLINTNEXTLINE: unconventional-assign-operator 141 OperationHandle operator=(const TemplatedIndexedValue &rhs) { 142 ValueHandle rrhs(rhs); 143 return Store(rrhs, getBase(), {indices.begin(), indices.end()}); 144 } 145 // NOLINTNEXTLINE: unconventional-assign-operator 146 OperationHandle operator=(ValueHandle rhs) { 147 return Store(rhs, getBase(), {indices.begin(), indices.end()}); 148 } 149 150 /// Emits a `load` when converting to a ValueHandle. ValueHandle()151 operator ValueHandle() const { 152 return Load(getBase(), {indices.begin(), indices.end()}); 153 } 154 155 /// Emits a `load` when converting to a Value. 156 Value operator*(void) const { 157 return Load(getBase(), {indices.begin(), indices.end()}).getValue(); 158 } 159 getBase()160 ValueHandle getBase() const { return base; } 161 162 /// Operator overloadings. 163 ValueHandle operator+(ValueHandle e); 164 ValueHandle operator-(ValueHandle e); 165 ValueHandle operator*(ValueHandle e); 166 ValueHandle operator/(ValueHandle e); 167 OperationHandle operator+=(ValueHandle e); 168 OperationHandle operator-=(ValueHandle e); 169 OperationHandle operator*=(ValueHandle e); 170 OperationHandle operator/=(ValueHandle e); 171 ValueHandle operator+(TemplatedIndexedValue e) { 172 return *this + static_cast<ValueHandle>(e); 173 } 174 ValueHandle operator-(TemplatedIndexedValue e) { 175 return *this - static_cast<ValueHandle>(e); 176 } 177 ValueHandle operator*(TemplatedIndexedValue e) { 178 return *this * static_cast<ValueHandle>(e); 179 } 180 ValueHandle operator/(TemplatedIndexedValue e) { 181 return *this / static_cast<ValueHandle>(e); 182 } 183 OperationHandle operator+=(TemplatedIndexedValue e) { 184 return this->operator+=(static_cast<ValueHandle>(e)); 185 } 186 OperationHandle operator-=(TemplatedIndexedValue e) { 187 return this->operator-=(static_cast<ValueHandle>(e)); 188 } 189 OperationHandle operator*=(TemplatedIndexedValue e) { 190 return this->operator*=(static_cast<ValueHandle>(e)); 191 } 192 OperationHandle operator/=(TemplatedIndexedValue e) { 193 return this->operator/=(static_cast<ValueHandle>(e)); 194 } 195 196 private: TemplatedIndexedValue(ValueHandle base,ArrayRef<ValueHandle> indices)197 TemplatedIndexedValue(ValueHandle base, ArrayRef<ValueHandle> indices) 198 : base(base), indices(indices.begin(), indices.end()) {} 199 append()200 TemplatedIndexedValue &append() { return *this; } 201 202 template <typename T, typename... Args> append(T index,Args...indices)203 TemplatedIndexedValue &append(T index, Args... indices) { 204 this->indices.push_back(static_cast<ValueHandle>(index)); 205 append(indices...); 206 return *this; 207 } 208 ValueHandle base; 209 SmallVector<ValueHandle, 8> indices; 210 }; 211 212 /// Operator overloadings. 213 template <typename Load, typename Store> 214 ValueHandle TemplatedIndexedValue<Load, Store>::operator+(ValueHandle e) { 215 using op::operator+; 216 return static_cast<ValueHandle>(*this) + e; 217 } 218 template <typename Load, typename Store> 219 ValueHandle TemplatedIndexedValue<Load, Store>::operator-(ValueHandle e) { 220 using op::operator-; 221 return static_cast<ValueHandle>(*this) - e; 222 } 223 template <typename Load, typename Store> 224 ValueHandle TemplatedIndexedValue<Load, Store>::operator*(ValueHandle e) { 225 using op::operator*; 226 return static_cast<ValueHandle>(*this) * e; 227 } 228 template <typename Load, typename Store> 229 ValueHandle TemplatedIndexedValue<Load, Store>::operator/(ValueHandle e) { 230 using op::operator/; 231 return static_cast<ValueHandle>(*this) / e; 232 } 233 234 template <typename Load, typename Store> 235 OperationHandle TemplatedIndexedValue<Load, Store>::operator+=(ValueHandle e) { 236 using op::operator+; 237 return Store(*this + e, getBase(), {indices.begin(), indices.end()}); 238 } 239 template <typename Load, typename Store> 240 OperationHandle TemplatedIndexedValue<Load, Store>::operator-=(ValueHandle e) { 241 using op::operator-; 242 return Store(*this - e, getBase(), {indices.begin(), indices.end()}); 243 } 244 template <typename Load, typename Store> 245 OperationHandle TemplatedIndexedValue<Load, Store>::operator*=(ValueHandle e) { 246 using op::operator*; 247 return Store(*this * e, getBase(), {indices.begin(), indices.end()}); 248 } 249 template <typename Load, typename Store> 250 OperationHandle TemplatedIndexedValue<Load, Store>::operator/=(ValueHandle e) { 251 using op::operator/; 252 return Store(*this / e, getBase(), {indices.begin(), indices.end()}); 253 } 254 255 } // namespace edsc 256 } // namespace mlir 257 258 #endif // MLIR_EDSC_HELPERS_H_ 259