1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_SYMBOLTABLE_H 10 #define MLIR_IR_SYMBOLTABLE_H 11 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/OpDefinition.h" 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/StringMap.h" 16 17 namespace mlir { 18 class Identifier; 19 class Operation; 20 21 /// This class allows for representing and managing the symbol table used by 22 /// operations with the 'SymbolTable' trait. Inserting into and erasing from 23 /// this SymbolTable will also insert and erase from the Operation given to it 24 /// at construction. 25 class SymbolTable { 26 public: 27 /// Build a symbol table with the symbols within the given operation. 28 SymbolTable(Operation *symbolTableOp); 29 30 /// Look up a symbol with the specified name, returning null if no such 31 /// name exists. Names never include the @ on them. 32 Operation *lookup(StringRef name) const; lookup(StringRef name)33 template <typename T> T lookup(StringRef name) const { 34 return dyn_cast_or_null<T>(lookup(name)); 35 } 36 37 /// Erase the given symbol from the table. 38 void erase(Operation *symbol); 39 40 /// Insert a new symbol into the table, and rename it as necessary to avoid 41 /// collisions. Also insert at the specified location in the body of the 42 /// associated operation if it is not already there. It is asserted that the 43 /// symbol is not inside another operation. 44 void insert(Operation *symbol, Block::iterator insertPt = {}); 45 46 /// Return the name of the attribute used for symbol names. getSymbolAttrName()47 static StringRef getSymbolAttrName() { return "sym_name"; } 48 49 /// Returns the associated operation. getOp()50 Operation *getOp() const { return symbolTableOp; } 51 52 /// Return the name of the attribute used for symbol visibility. getVisibilityAttrName()53 static StringRef getVisibilityAttrName() { return "sym_visibility"; } 54 55 //===--------------------------------------------------------------------===// 56 // Symbol Utilities 57 //===--------------------------------------------------------------------===// 58 59 /// An enumeration detailing the different visibility types that a symbol may 60 /// have. 61 enum class Visibility { 62 /// The symbol is public and may be referenced anywhere internal or external 63 /// to the visible references in the IR. 64 Public, 65 66 /// The symbol is private and may only be referenced by SymbolRefAttrs local 67 /// to the operations within the current symbol table. 68 Private, 69 70 /// The symbol is visible to the current IR, which may include operations in 71 /// symbol tables above the one that owns the current symbol. `Nested` 72 /// visibility allows for referencing a symbol outside of its current symbol 73 /// table, while retaining the ability to observe all uses. 74 Nested, 75 }; 76 77 /// Returns the name of the given symbol operation. 78 static StringRef getSymbolName(Operation *symbol); 79 /// Sets the name of the given symbol operation. 80 static void setSymbolName(Operation *symbol, StringRef name); 81 82 /// Returns the visibility of the given symbol operation. 83 static Visibility getSymbolVisibility(Operation *symbol); 84 /// Sets the visibility of the given symbol operation. 85 static void setSymbolVisibility(Operation *symbol, Visibility vis); 86 87 /// Returns the nearest symbol table from a given operation `from`. Returns 88 /// nullptr if no valid parent symbol table could be found. 89 static Operation *getNearestSymbolTable(Operation *from); 90 91 /// Walks all symbol table operations nested within, and including, `op`. For 92 /// each symbol table operation, the provided callback is invoked with the op 93 /// and a boolean signifying if the symbols within that symbol table can be 94 /// treated as if all uses within the IR are visible to the caller. 95 /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols 96 /// within `op` are visible. 97 static void walkSymbolTables(Operation *op, bool allSymUsesVisible, 98 function_ref<void(Operation *, bool)> callback); 99 100 /// Returns the operation registered with the given symbol name with the 101 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation 102 /// with the 'OpTrait::SymbolTable' trait. 103 static Operation *lookupSymbolIn(Operation *op, StringRef symbol); 104 static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); 105 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 106 /// by a given SymbolRefAttr. Returns failure if any of the nested references 107 /// could not be resolved. 108 static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol, 109 SmallVectorImpl<Operation *> &symbols); 110 111 /// Returns the operation registered with the given symbol name within the 112 /// closest parent operation of, or including, 'from' with the 113 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 114 /// found. 115 static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); 116 static Operation *lookupNearestSymbolFrom(Operation *from, 117 SymbolRefAttr symbol); 118 template <typename T> lookupNearestSymbolFrom(Operation * from,StringRef symbol)119 static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { 120 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 121 } 122 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)123 static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 124 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 125 } 126 127 /// This class represents a specific symbol use. 128 class SymbolUse { 129 public: SymbolUse(Operation * op,SymbolRefAttr symbolRef)130 SymbolUse(Operation *op, SymbolRefAttr symbolRef) 131 : owner(op), symbolRef(symbolRef) {} 132 133 /// Return the operation user of this symbol reference. getUser()134 Operation *getUser() const { return owner; } 135 136 /// Return the symbol reference that this use represents. getSymbolRef()137 SymbolRefAttr getSymbolRef() const { return symbolRef; } 138 139 private: 140 /// The operation that this access is held by. 141 Operation *owner; 142 143 /// The symbol reference that this use represents. 144 SymbolRefAttr symbolRef; 145 }; 146 147 /// This class implements a range of SymbolRef uses. 148 class UseRange { 149 public: UseRange(std::vector<SymbolUse> && uses)150 UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} 151 152 using iterator = std::vector<SymbolUse>::const_iterator; begin()153 iterator begin() const { return uses.begin(); } end()154 iterator end() const { return uses.end(); } empty()155 bool empty() const { return uses.empty(); } 156 157 private: 158 std::vector<SymbolUse> uses; 159 }; 160 161 /// Get an iterator range for all of the uses, for any symbol, that are nested 162 /// within the given operation 'from'. This does not traverse into any nested 163 /// symbol tables. This function returns None if there are any unknown 164 /// operations that may potentially be symbol tables. 165 static Optional<UseRange> getSymbolUses(Operation *from); 166 static Optional<UseRange> getSymbolUses(Region *from); 167 168 /// Get all of the uses of the given symbol that are nested within the given 169 /// operation 'from'. This does not traverse into any nested symbol tables. 170 /// This function returns None if there are any unknown operations that may 171 /// potentially be symbol tables. 172 static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from); 173 static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from); 174 static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from); 175 static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from); 176 177 /// Return if the given symbol is known to have no uses that are nested 178 /// within the given operation 'from'. This does not traverse into any nested 179 /// symbol tables. This function will also return false if there are any 180 /// unknown operations that may potentially be symbol tables. This doesn't 181 /// necessarily mean that there are no uses, we just can't conservatively 182 /// prove it. 183 static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); 184 static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); 185 static bool symbolKnownUseEmpty(StringRef symbol, Region *from); 186 static bool symbolKnownUseEmpty(Operation *symbol, Region *from); 187 188 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the 189 /// provided symbol 'newSymbol' that are nested within the given operation 190 /// 'from'. This does not traverse into any nested symbol tables. If there are 191 /// any unknown operations that may potentially be symbol tables, no uses are 192 /// replaced and failure is returned. 193 static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, 194 StringRef newSymbol, 195 Operation *from); 196 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, 197 StringRef newSymbolName, 198 Operation *from); 199 static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, 200 StringRef newSymbol, Region *from); 201 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, 202 StringRef newSymbolName, 203 Region *from); 204 205 private: 206 Operation *symbolTableOp; 207 208 /// This is a mapping from a name to the symbol with that name. 209 llvm::StringMap<Operation *> symbolTable; 210 211 /// This is used when name conflicts are detected. 212 unsigned uniquingCounter = 0; 213 }; 214 215 raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility); 216 217 //===----------------------------------------------------------------------===// 218 // SymbolTableCollection 219 //===----------------------------------------------------------------------===// 220 221 /// This class represents a collection of `SymbolTable`s. This simplifies 222 /// certain algorithms that run recursively on nested symbol tables. Symbol 223 /// tables are constructed lazily to reduce the upfront cost of constructing 224 /// unnecessary tables. 225 class SymbolTableCollection { 226 public: 227 /// Look up a symbol with the specified name within the specified symbol table 228 /// operation, returning null if no such name exists. 229 Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol); 230 Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); 231 template <typename T, typename NameT> lookupSymbolIn(Operation * symbolTableOp,NameT && name)232 T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const { 233 return dyn_cast_or_null<T>( 234 lookupSymbolIn(symbolTableOp, std::forward<NameT>(name))); 235 } 236 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 237 /// by a given SymbolRefAttr when resolved within the provided symbol table 238 /// operation. Returns failure if any of the nested references could not be 239 /// resolved. 240 LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, 241 SmallVectorImpl<Operation *> &symbols); 242 243 /// Returns the operation registered with the given symbol name within the 244 /// closest parent operation of, or including, 'from' with the 245 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 246 /// found. 247 Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); 248 Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); 249 template <typename T> lookupNearestSymbolFrom(Operation * from,StringRef symbol)250 T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { 251 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 252 } 253 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)254 T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 255 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 256 } 257 258 /// Lookup, or create, a symbol table for an operation. 259 SymbolTable &getSymbolTable(Operation *op); 260 261 private: 262 /// The constructed symbol tables nested within this table. 263 DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables; 264 }; 265 266 //===----------------------------------------------------------------------===// 267 // SymbolUserMap 268 //===----------------------------------------------------------------------===// 269 270 /// This class represents a map of symbols to users, and provides efficient 271 /// implementations of symbol queries related to users; such as collecting the 272 /// users of a symbol, replacing all uses, etc. 273 class SymbolUserMap { 274 public: 275 /// Build a user map for all of the symbols defined in regions nested under 276 /// 'symbolTableOp'. A reference to the provided symbol table collection is 277 /// kept by the user map to ensure efficient lookups, thus the lifetime should 278 /// extend beyond that of this map. 279 SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp); 280 281 /// Return the users of the provided symbol operation. getUsers(Operation * symbol)282 ArrayRef<Operation *> getUsers(Operation *symbol) const { 283 auto it = symbolToUsers.find(symbol); 284 return it != symbolToUsers.end() ? it->second.getArrayRef() : llvm::None; 285 } 286 287 /// Return true if the given symbol has no uses. use_empty(Operation * symbol)288 bool use_empty(Operation *symbol) const { 289 return !symbolToUsers.count(symbol); 290 } 291 292 /// Replace all of the uses of the given symbol with `newSymbolName`. 293 void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName); 294 295 private: 296 /// A reference to the symbol table used to construct this map. 297 SymbolTableCollection &symbolTable; 298 299 /// A map of symbol operations to symbol users. 300 DenseMap<Operation *, SetVector<Operation *>> symbolToUsers; 301 }; 302 303 //===----------------------------------------------------------------------===// 304 // SymbolTable Trait Types 305 //===----------------------------------------------------------------------===// 306 307 namespace detail { 308 LogicalResult verifySymbolTable(Operation *op); 309 LogicalResult verifySymbol(Operation *op); 310 } // namespace detail 311 312 namespace OpTrait { 313 /// A trait used to provide symbol table functionalities to a region operation. 314 /// This operation must hold exactly 1 region. Once attached, all operations 315 /// that are directly within the region, i.e not including those within child 316 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will 317 /// be verified to ensure that the names are uniqued. These operations must also 318 /// adhere to the constraints defined by the `Symbol` trait, even if they do not 319 /// inherit from it. 320 template <typename ConcreteType> 321 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { 322 public: verifyTrait(Operation * op)323 static LogicalResult verifyTrait(Operation *op) { 324 return ::mlir::detail::verifySymbolTable(op); 325 } 326 327 /// Look up a symbol with the specified name, returning null if no such 328 /// name exists. Symbol names never include the @ on them. Note: This 329 /// performs a linear scan of held symbols. lookupSymbol(StringRef name)330 Operation *lookupSymbol(StringRef name) { 331 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); 332 } lookupSymbol(StringRef name)333 template <typename T> T lookupSymbol(StringRef name) { 334 return dyn_cast_or_null<T>(lookupSymbol(name)); 335 } lookupSymbol(SymbolRefAttr symbol)336 Operation *lookupSymbol(SymbolRefAttr symbol) { 337 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); 338 } lookupSymbol(SymbolRefAttr symbol)339 template <typename T> T lookupSymbol(SymbolRefAttr symbol) { 340 return dyn_cast_or_null<T>(lookupSymbol(symbol)); 341 } 342 }; 343 344 } // end namespace OpTrait 345 346 //===----------------------------------------------------------------------===// 347 // Visibility parsing implementation. 348 //===----------------------------------------------------------------------===// 349 350 namespace impl { 351 /// Parse an optional visibility attribute keyword (i.e., public, private, or 352 /// nested) without quotes in a string attribute named 'attrName'. 353 ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, 354 NamedAttrList &attrs); 355 } // end namespace impl 356 357 } // end namespace mlir 358 359 /// Include the generated symbol interfaces. 360 #include "mlir/IR/SymbolInterfaces.h.inc" 361 362 #endif // MLIR_IR_SYMBOLTABLE_H 363