1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_SYMBOLTABLE_H 10 #define MLIR_IR_SYMBOLTABLE_H 11 12 #include "mlir/IR/OpDefinition.h" 13 #include "llvm/ADT/StringMap.h" 14 15 namespace mlir { 16 class Identifier; 17 class Operation; 18 19 /// This class allows for representing and managing the symbol table used by 20 /// operations with the 'SymbolTable' trait. Inserting into and erasing from 21 /// this SymbolTable will also insert and erase from the Operation given to it 22 /// at construction. 23 class SymbolTable { 24 public: 25 /// Build a symbol table with the symbols within the given operation. 26 SymbolTable(Operation *symbolTableOp); 27 28 /// Look up a symbol with the specified name, returning null if no such 29 /// name exists. Names never include the @ on them. 30 Operation *lookup(StringRef name) const; lookup(StringRef name)31 template <typename T> T lookup(StringRef name) const { 32 return dyn_cast_or_null<T>(lookup(name)); 33 } 34 35 /// Erase the given symbol from the table. 36 void erase(Operation *symbol); 37 38 /// Insert a new symbol into the table, and rename it as necessary to avoid 39 /// collisions. Also insert at the specified location in the body of the 40 /// associated operation. 41 void insert(Operation *symbol, Block::iterator insertPt = {}); 42 43 /// Return the name of the attribute used for symbol names. getSymbolAttrName()44 static StringRef getSymbolAttrName() { return "sym_name"; } 45 46 /// Returns the associated operation. getOp()47 Operation *getOp() const { return symbolTableOp; } 48 49 /// Return the name of the attribute used for symbol visibility. getVisibilityAttrName()50 static StringRef getVisibilityAttrName() { return "sym_visibility"; } 51 52 //===--------------------------------------------------------------------===// 53 // Symbol Utilities 54 //===--------------------------------------------------------------------===// 55 56 /// An enumeration detailing the different visibility types that a symbol may 57 /// have. 58 enum class Visibility { 59 /// The symbol is public and may be referenced anywhere internal or external 60 /// to the visible references in the IR. 61 Public, 62 63 /// The symbol is private and may only be referenced by SymbolRefAttrs local 64 /// to the operations within the current symbol table. 65 Private, 66 67 /// The symbol is visible to the current IR, which may include operations in 68 /// symbol tables above the one that owns the current symbol. `Nested` 69 /// visibility allows for referencing a symbol outside of its current symbol 70 /// table, while retaining the ability to observe all uses. 71 Nested, 72 }; 73 74 /// Returns true if the given operation defines a symbol. 75 static bool isSymbol(Operation *op); 76 77 /// Returns the name of the given symbol operation. 78 static StringRef getSymbolName(Operation *symbol); 79 /// Sets the name of the given symbol operation. 80 static void setSymbolName(Operation *symbol, StringRef name); 81 82 /// Returns the visibility of the given symbol operation. 83 static Visibility getSymbolVisibility(Operation *symbol); 84 /// Sets the visibility of the given symbol operation. 85 static void setSymbolVisibility(Operation *symbol, Visibility vis); 86 87 /// Returns the operation registered with the given symbol name with the 88 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation 89 /// with the 'OpTrait::SymbolTable' trait. 90 static Operation *lookupSymbolIn(Operation *op, StringRef symbol); 91 static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); 92 93 /// Returns the operation registered with the given symbol name within the 94 /// closest parent operation of, or including, 'from' with the 95 /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was 96 /// found. 97 static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); 98 static Operation *lookupNearestSymbolFrom(Operation *from, 99 SymbolRefAttr symbol); 100 101 /// This class represents a specific symbol use. 102 class SymbolUse { 103 public: SymbolUse(Operation * op,SymbolRefAttr symbolRef)104 SymbolUse(Operation *op, SymbolRefAttr symbolRef) 105 : owner(op), symbolRef(symbolRef) {} 106 107 /// Return the operation user of this symbol reference. getUser()108 Operation *getUser() const { return owner; } 109 110 /// Return the symbol reference that this use represents. getSymbolRef()111 SymbolRefAttr getSymbolRef() const { return symbolRef; } 112 113 private: 114 /// The operation that this access is held by. 115 Operation *owner; 116 117 /// The symbol reference that this use represents. 118 SymbolRefAttr symbolRef; 119 }; 120 121 /// This class implements a range of SymbolRef uses. 122 class UseRange { 123 public: UseRange(std::vector<SymbolUse> && uses)124 UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {} 125 126 using iterator = std::vector<SymbolUse>::const_iterator; begin()127 iterator begin() const { return uses.begin(); } end()128 iterator end() const { return uses.end(); } 129 130 private: 131 std::vector<SymbolUse> uses; 132 }; 133 134 /// Get an iterator range for all of the uses, for any symbol, that are nested 135 /// within the given operation 'from'. This does not traverse into any nested 136 /// symbol tables, and will also only return uses on 'from' if it does not 137 /// also define a symbol table. This is because we treat the region as the 138 /// boundary of the symbol table, and not the op itself. This function returns 139 /// None if there are any unknown operations that may potentially be symbol 140 /// tables. 141 static Optional<UseRange> getSymbolUses(Operation *from); 142 143 /// Get all of the uses of the given symbol that are nested within the given 144 /// operation 'from'. This does not traverse into any nested symbol tables, 145 /// and will also only return uses on 'from' if it does not also define a 146 /// symbol table. This is because we treat the region as the boundary of the 147 /// symbol table, and not the op itself. This function returns None if there 148 /// are any unknown operations that may potentially be symbol tables. 149 static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from); 150 static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from); 151 152 /// Return if the given symbol is known to have no uses that are nested 153 /// within the given operation 'from'. This does not traverse into any nested 154 /// symbol tables, and will also only count uses on 'from' if it does not also 155 /// define a symbol table. This is because we treat the region as the boundary 156 /// of the symbol table, and not the op itself. This function will also return 157 /// false if there are any unknown operations that may potentially be symbol 158 /// tables. This doesn't necessarily mean that there are no uses, we just 159 /// can't conservatively prove it. 160 static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); 161 static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); 162 163 /// Attempt to replace all uses of the given symbol 'oldSymbol' with the 164 /// provided symbol 'newSymbol' that are nested within the given operation 165 /// 'from'. This does not traverse into any nested symbol tables, and will 166 /// also only replace uses on 'from' if it does not also define a symbol 167 /// table. This is because we treat the region as the boundary of the symbol 168 /// table, and not the op itself. If there are any unknown operations that may 169 /// potentially be symbol tables, no uses are replaced and failure is 170 /// returned. 171 LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, 172 StringRef newSymbol, 173 Operation *from); 174 LLVM_NODISCARD static LogicalResult 175 replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName, 176 Operation *from); 177 178 private: 179 Operation *symbolTableOp; 180 181 /// This is a mapping from a name to the symbol with that name. 182 llvm::StringMap<Operation *> symbolTable; 183 184 /// This is used when name conflicts are detected. 185 unsigned uniquingCounter = 0; 186 }; 187 188 //===----------------------------------------------------------------------===// 189 // SymbolTable Trait Types 190 //===----------------------------------------------------------------------===// 191 192 namespace OpTrait { 193 namespace impl { 194 LogicalResult verifySymbolTable(Operation *op); 195 LogicalResult verifySymbol(Operation *op); 196 } // namespace impl 197 198 /// A trait used to provide symbol table functionalities to a region operation. 199 /// This operation must hold exactly 1 region. Once attached, all operations 200 /// that are directly within the region, i.e not including those within child 201 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will 202 /// be verified to ensure that the names are uniqued. These operations must also 203 /// adhere to the constraints defined by the `Symbol` trait, even if they do not 204 /// inherit from it. 205 template <typename ConcreteType> 206 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> { 207 public: verifyTrait(Operation * op)208 static LogicalResult verifyTrait(Operation *op) { 209 return impl::verifySymbolTable(op); 210 } 211 212 /// Look up a symbol with the specified name, returning null if no such 213 /// name exists. Symbol names never include the @ on them. Note: This 214 /// performs a linear scan of held symbols. lookupSymbol(StringRef name)215 Operation *lookupSymbol(StringRef name) { 216 return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); 217 } lookupSymbol(StringRef name)218 template <typename T> T lookupSymbol(StringRef name) { 219 return dyn_cast_or_null<T>(lookupSymbol(name)); 220 } 221 }; 222 223 /// A trait used to define a symbol that can be used on operations within a 224 /// symbol table. Operations using this trait must adhere to the following: 225 /// * Have a StringAttr attribute named 'SymbolTable::getSymbolAttrName()'. 226 template <typename ConcreteType> 227 class Symbol : public TraitBase<ConcreteType, Symbol> { 228 public: 229 using Visibility = mlir::SymbolTable::Visibility; 230 verifyTrait(Operation * op)231 static LogicalResult verifyTrait(Operation *op) { 232 return impl::verifySymbol(op); 233 } 234 235 /// Returns the name of this symbol. getName()236 StringRef getName() { 237 return this->getOperation() 238 ->template getAttrOfType<StringAttr>( 239 mlir::SymbolTable::getSymbolAttrName()) 240 .getValue(); 241 } 242 243 /// Set the name of this symbol. setName(StringRef name)244 void setName(StringRef name) { 245 this->getOperation()->setAttr( 246 mlir::SymbolTable::getSymbolAttrName(), 247 StringAttr::get(name, this->getOperation()->getContext())); 248 } 249 250 /// Returns the visibility of the current symbol. getVisibility()251 Visibility getVisibility() { 252 return mlir::SymbolTable::getSymbolVisibility(this->getOperation()); 253 } 254 255 /// Sets the visibility of the current symbol. setVisibility(Visibility vis)256 void setVisibility(Visibility vis) { 257 mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis); 258 } 259 260 /// Get all of the uses of the current symbol that are nested within the given 261 /// operation 'from'. 262 /// Note: See mlir::SymbolTable::getSymbolUses for more details. getSymbolUses(Operation * from)263 Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) { 264 return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from); 265 } 266 267 /// Return if the current symbol is known to have no uses that are nested 268 /// within the given operation 'from'. 269 /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details. symbolKnownUseEmpty(Operation * from)270 bool symbolKnownUseEmpty(Operation *from) { 271 return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from); 272 } 273 274 /// Attempt to replace all uses of the current symbol with the provided symbol 275 /// 'newSymbol' that are nested within the given operation 'from'. 276 /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details. replaceAllSymbolUses(StringRef newSymbol,Operation * from)277 LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol, 278 Operation *from) { 279 return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(), 280 newSymbol, from); 281 } 282 }; 283 284 } // end namespace OpTrait 285 } // end namespace mlir 286 287 #endif // MLIR_IR_SYMBOLTABLE_H 288