1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_SYMBOLTABLE_H 10 #define MLIR_IR_SYMBOLTABLE_H 11 12 #include "mlir/IR/Attributes.h" 13 #include "mlir/IR/OpDefinition.h" 14 #include "llvm/ADT/SetVector.h" 15 #include "llvm/ADT/StringMap.h" 16 17 namespace mlir { 18 class Identifier; 19 class Operation; 20 21 /// This class allows for representing and managing the symbol table used by 22 /// operations with the 'SymbolTable' trait. Inserting into and erasing from 23 /// this SymbolTable will also insert and erase from the Operation given to it 24 /// at construction. 25 class SymbolTable { 26 public: 27 /// Build a symbol table with the symbols within the given operation. 28 SymbolTable(Operation *symbolTableOp); 29 30 /// Look up a symbol with the specified name, returning null if no such 31 /// name exists. Names never include the @ on them. 32 Operation *lookup(StringRef name) const; 33 template <typename T> lookup(StringRef name)34 T lookup(StringRef name) const { 35 return dyn_cast_or_null<T>(lookup(name)); 36 } 37 38 /// Look up a symbol with the specified name, returning null if no such 39 /// name exists. Names never include the @ on them. 40 Operation *lookup(StringAttr name) const; 41 template <typename T> lookup(StringAttr name)42 T lookup(StringAttr name) const { 43 return dyn_cast_or_null<T>(lookup(name)); 44 } 45 46 /// Erase the given symbol from the table. 47 void erase(Operation *symbol); 48 49 /// Insert a new symbol into the table, and rename it as necessary to avoid 50 /// collisions. Also insert at the specified location in the body of the 51 /// associated operation if it is not already there. It is asserted that the 52 /// symbol is not inside another operation. 53 void insert(Operation *symbol, Block::iterator insertPt = {}); 54 55 /// Return the name of the attribute used for symbol names. getSymbolAttrName()56 static StringRef getSymbolAttrName() { return "sym_name"; } 57 58 /// Returns the associated operation. getOp()59 Operation *getOp() const { return symbolTableOp; } 60 61 /// Return the name of the attribute used for symbol visibility. getVisibilityAttrName()62 static StringRef getVisibilityAttrName() { return "sym_visibility"; } 63 64 //===--------------------------------------------------------------------===// 65 // Symbol Utilities 66 //===--------------------------------------------------------------------===// 67 68 /// An enumeration detailing the different visibility types that a symbol may 69 /// have. 70 enum class Visibility { 71 /// The symbol is public and may be referenced anywhere internal or external 72 /// to the visible references in the IR. 73 Public, 74 75 /// The symbol is private and may only be referenced by SymbolRefAttrs local 76 /// to the operations within the current symbol table. 77 Private, 78 79 /// The symbol is visible to the current IR, which may include operations in 80 /// symbol tables above the one that owns the current symbol. `Nested` 81 /// visibility allows for referencing a symbol outside of its current symbol 82 /// table, while retaining the ability to observe all uses. 83 Nested, 84 }; 85 86 /// Returns the name of the given symbol operation, aborting if no symbol is 87 /// present. 88 static StringAttr getSymbolName(Operation *symbol); 89 90 /// Sets the name of the given symbol operation. 91 static void setSymbolName(Operation *symbol, StringAttr name); setSymbolName(Operation * symbol,StringRef name)92 static void setSymbolName(Operation *symbol, StringRef name) { 93 setSymbolName(symbol, StringAttr::get(symbol->getContext(), name)); 94 } 95 96 /// Returns the visibility of the given symbol operation. 97 static Visibility getSymbolVisibility(Operation *symbol); 98 /// Sets the visibility of the given symbol operation. 99 static void setSymbolVisibility(Operation *symbol, Visibility vis); 100 101 /// Returns the nearest symbol table from a given operation `from`. Returns 102 /// nullptr if no valid parent symbol table could be found. 103 static Operation *getNearestSymbolTable(Operation *from); 104 105 /// Walks all symbol table operations nested within, and including, `op`. For 106 /// each symbol table operation, the provided callback is invoked with the op 107 /// and a boolean signifying if the symbols within that symbol table can be 108 /// treated as if all uses within the IR are visible to the caller. 109 /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols 110 /// within `op` are visible. 111 static void walkSymbolTables(Operation *op, bool allSymUsesVisible, 112 function_ref<void(Operation *, bool)> callback); 113 114 /// Returns the operation registered with the given symbol name with the 115 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation 116 /// with the 'OpTrait::SymbolTable' trait. 117 static Operation *lookupSymbolIn(Operation *op, StringAttr symbol); lookupSymbolIn(Operation * op,StringRef symbol)118 static Operation *lookupSymbolIn(Operation *op, StringRef symbol) { 119 return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol)); 120 } 121 static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); 122 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 123 /// by a given SymbolRefAttr. Returns failure if any of the nested references 124 /// could not be resolved. 125 static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol, 126 SmallVectorImpl<Operation *> &symbols); 127 128 /// Returns the operation registered with the given symbol name within the 129 /// closest parent operation of, or including, 'from' with the 130 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 131 /// found. 132 static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); 133 static Operation *lookupNearestSymbolFrom(Operation *from, 134 SymbolRefAttr symbol); 135 template <typename T> lookupNearestSymbolFrom(Operation * from,StringAttr symbol)136 static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { 137 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 138 } 139 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)140 static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 141 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 142 } 143 144 /// This class represents a specific symbol use. 145 class SymbolUse { 146 public: SymbolUse(Operation * op,SymbolRefAttr symbolRef)147 SymbolUse(Operation *op, SymbolRefAttr symbolRef) 148 : owner(op), symbolRef(symbolRef) {} 149 150 /// Return the operation user of this symbol reference. getUser()151 Operation *getUser() const { return owner; } 152 153 /// Return the symbol reference that this use represents. getSymbolRef()154 SymbolRefAttr getSymbolRef() const { return symbolRef; } 155 156 private: 157 /// The operation that this access is held by. 158 Operation *owner; 159 160 /// The symbol reference that this use represents. 161 SymbolRefAttr symbolRef; 162 }; 163 164 /// This class implements a range of SymbolRef uses. 165 class UseRange { 166 public: UseRange(std::vector<SymbolUse> && uses)167 UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} 168 169 using iterator = std::vector<SymbolUse>::const_iterator; begin()170 iterator begin() const { return uses.begin(); } end()171 iterator end() const { return uses.end(); } empty()172 bool empty() const { return uses.empty(); } 173 174 private: 175 std::vector<SymbolUse> uses; 176 }; 177 178 /// Get an iterator range for all of the uses, for any symbol, that are nested 179 /// within the given operation 'from'. This does not traverse into any nested 180 /// symbol tables. This function returns None if there are any unknown 181 /// operations that may potentially be symbol tables. 182 static Optional<UseRange> getSymbolUses(Operation *from); 183 static Optional<UseRange> getSymbolUses(Region *from); 184 185 /// Get all of the uses of the given symbol that are nested within the given 186 /// operation 'from'. This does not traverse into any nested symbol tables. 187 /// This function returns None if there are any unknown operations that may 188 /// potentially be symbol tables. 189 static Optional<UseRange> getSymbolUses(StringAttr symbol, Operation *from); 190 static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from); 191 static Optional<UseRange> getSymbolUses(StringAttr symbol, Region *from); 192 static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from); 193 194 /// Return if the given symbol is known to have no uses that are nested 195 /// within the given operation 'from'. This does not traverse into any nested 196 /// symbol tables. This function will also return false if there are any 197 /// unknown operations that may potentially be symbol tables. This doesn't 198 /// necessarily mean that there are no uses, we just can't conservatively 199 /// prove it. 200 static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from); 201 static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); 202 static bool symbolKnownUseEmpty(StringAttr symbol, Region *from); 203 static bool symbolKnownUseEmpty(Operation *symbol, Region *from); 204 205 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the 206 /// provided symbol 'newSymbol' that are nested within the given operation 207 /// 'from'. This does not traverse into any nested symbol tables. If there are 208 /// any unknown operations that may potentially be symbol tables, no uses are 209 /// replaced and failure is returned. 210 static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, 211 StringAttr newSymbol, 212 Operation *from); 213 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, 214 StringAttr newSymbolName, 215 Operation *from); 216 static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, 217 StringAttr newSymbol, Region *from); 218 static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, 219 StringAttr newSymbolName, 220 Region *from); 221 222 private: 223 Operation *symbolTableOp; 224 225 /// This is a mapping from a name to the symbol with that name. They key is 226 /// always known to be a StringAttr. 227 DenseMap<Attribute, Operation *> symbolTable; 228 229 /// This is used when name conflicts are detected. 230 unsigned uniquingCounter = 0; 231 }; 232 233 raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility); 234 235 //===----------------------------------------------------------------------===// 236 // SymbolTableCollection 237 //===----------------------------------------------------------------------===// 238 239 /// This class represents a collection of `SymbolTable`s. This simplifies 240 /// certain algorithms that run recursively on nested symbol tables. Symbol 241 /// tables are constructed lazily to reduce the upfront cost of constructing 242 /// unnecessary tables. 243 class SymbolTableCollection { 244 public: 245 /// Look up a symbol with the specified name within the specified symbol table 246 /// operation, returning null if no such name exists. 247 Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol); 248 Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); 249 template <typename T, typename NameT> lookupSymbolIn(Operation * symbolTableOp,NameT && name)250 T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const { 251 return dyn_cast_or_null<T>( 252 lookupSymbolIn(symbolTableOp, std::forward<NameT>(name))); 253 } 254 /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced 255 /// by a given SymbolRefAttr when resolved within the provided symbol table 256 /// operation. Returns failure if any of the nested references could not be 257 /// resolved. 258 LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, 259 SmallVectorImpl<Operation *> &symbols); 260 261 /// Returns the operation registered with the given symbol name within the 262 /// closest parent operation of, or including, 'from' with the 263 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 264 /// found. 265 Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); 266 Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); 267 template <typename T> lookupNearestSymbolFrom(Operation * from,StringAttr symbol)268 T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { 269 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 270 } 271 template <typename T> lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)272 T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) { 273 return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol)); 274 } 275 276 /// Lookup, or create, a symbol table for an operation. 277 SymbolTable &getSymbolTable(Operation *op); 278 279 private: 280 /// The constructed symbol tables nested within this table. 281 DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables; 282 }; 283 284 //===----------------------------------------------------------------------===// 285 // SymbolUserMap 286 //===----------------------------------------------------------------------===// 287 288 /// This class represents a map of symbols to users, and provides efficient 289 /// implementations of symbol queries related to users; such as collecting the 290 /// users of a symbol, replacing all uses, etc. 291 class SymbolUserMap { 292 public: 293 /// Build a user map for all of the symbols defined in regions nested under 294 /// 'symbolTableOp'. A reference to the provided symbol table collection is 295 /// kept by the user map to ensure efficient lookups, thus the lifetime should 296 /// extend beyond that of this map. 297 SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp); 298 299 /// Return the users of the provided symbol operation. getUsers(Operation * symbol)300 ArrayRef<Operation *> getUsers(Operation *symbol) const { 301 auto it = symbolToUsers.find(symbol); 302 return it != symbolToUsers.end() ? it->second.getArrayRef() : llvm::None; 303 } 304 305 /// Return true if the given symbol has no uses. use_empty(Operation * symbol)306 bool use_empty(Operation *symbol) const { 307 return !symbolToUsers.count(symbol); 308 } 309 310 /// Replace all of the uses of the given symbol with `newSymbolName`. 311 void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName); 312 313 private: 314 /// A reference to the symbol table used to construct this map. 315 SymbolTableCollection &symbolTable; 316 317 /// A map of symbol operations to symbol users. 318 DenseMap<Operation *, SetVector<Operation *>> symbolToUsers; 319 }; 320 321 //===----------------------------------------------------------------------===// 322 // SymbolTable Trait Types 323 //===----------------------------------------------------------------------===// 324 325 namespace detail { 326 LogicalResult verifySymbolTable(Operation *op); 327 LogicalResult verifySymbol(Operation *op); 328 } // namespace detail 329 330 namespace OpTrait { 331 /// A trait used to provide symbol table functionalities to a region operation. 332 /// This operation must hold exactly 1 region. Once attached, all operations 333 /// that are directly within the region, i.e not including those within child 334 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will 335 /// be verified to ensure that the names are uniqued. These operations must also 336 /// adhere to the constraints defined by the `Symbol` trait, even if they do not 337 /// inherit from it. 338 template <typename ConcreteType> 339 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { 340 public: verifyTrait(Operation * op)341 static LogicalResult verifyTrait(Operation *op) { 342 return ::mlir::detail::verifySymbolTable(op); 343 } 344 345 /// Look up a symbol with the specified name, returning null if no such 346 /// name exists. Symbol names never include the @ on them. Note: This 347 /// performs a linear scan of held symbols. lookupSymbol(StringAttr name)348 Operation *lookupSymbol(StringAttr name) { 349 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); 350 } 351 template <typename T> lookupSymbol(StringAttr name)352 T lookupSymbol(StringAttr name) { 353 return dyn_cast_or_null<T>(lookupSymbol(name)); 354 } lookupSymbol(SymbolRefAttr symbol)355 Operation *lookupSymbol(SymbolRefAttr symbol) { 356 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); 357 } 358 template <typename T> lookupSymbol(SymbolRefAttr symbol)359 T lookupSymbol(SymbolRefAttr symbol) { 360 return dyn_cast_or_null<T>(lookupSymbol(symbol)); 361 } 362 lookupSymbol(StringRef name)363 Operation *lookupSymbol(StringRef name) { 364 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); 365 } 366 template <typename T> lookupSymbol(StringRef name)367 T lookupSymbol(StringRef name) { 368 return dyn_cast_or_null<T>(lookupSymbol(name)); 369 } 370 }; 371 372 } // end namespace OpTrait 373 374 //===----------------------------------------------------------------------===// 375 // Visibility parsing implementation. 376 //===----------------------------------------------------------------------===// 377 378 namespace impl { 379 /// Parse an optional visibility attribute keyword (i.e., public, private, or 380 /// nested) without quotes in a string attribute named 'attrName'. 381 ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser, 382 NamedAttrList &attrs); 383 } // end namespace impl 384 385 } // end namespace mlir 386 387 /// Include the generated symbol interfaces. 388 #include "mlir/IR/SymbolInterfaces.h.inc" 389 390 #endif // MLIR_IR_SYMBOLTABLE_H 391