//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_SYMBOLTABLE_H
10 #define MLIR_IR_SYMBOLTABLE_H
11 
12 #include "mlir/IR/OpDefinition.h"
13 #include "llvm/ADT/StringMap.h"
14 
15 namespace mlir {
16 class Identifier;
17 class Operation;
18 
19 /// This class allows for representing and managing the symbol table used by
20 /// operations with the 'SymbolTable' trait. Inserting into and erasing from
21 /// this SymbolTable will also insert and erase from the Operation given to it
22 /// at construction.
23 class SymbolTable {
24 public:
25   /// Build a symbol table with the symbols within the given operation.
26   SymbolTable(Operation *symbolTableOp);
27 
28   /// Look up a symbol with the specified name, returning null if no such
29   /// name exists. Names never include the @ on them.
30   Operation *lookup(StringRef name) const;
lookup(StringRef name)31   template <typename T> T lookup(StringRef name) const {
32     return dyn_cast_or_null<T>(lookup(name));
33   }
34 
35   /// Erase the given symbol from the table.
36   void erase(Operation *symbol);
37 
38   /// Insert a new symbol into the table, and rename it as necessary to avoid
39   /// collisions. Also insert at the specified location in the body of the
40   /// associated operation.
41   void insert(Operation *symbol, Block::iterator insertPt = {});
42 
43   /// Return the name of the attribute used for symbol names.
getSymbolAttrName()44   static StringRef getSymbolAttrName() { return "sym_name"; }
45 
46   /// Returns the associated operation.
getOp()47   Operation *getOp() const { return symbolTableOp; }
48 
49   /// Return the name of the attribute used for symbol visibility.
getVisibilityAttrName()50   static StringRef getVisibilityAttrName() { return "sym_visibility"; }
51 
52   //===--------------------------------------------------------------------===//
53   // Symbol Utilities
54   //===--------------------------------------------------------------------===//
55 
56   /// An enumeration detailing the different visibility types that a symbol may
57   /// have.
58   enum class Visibility {
59     /// The symbol is public and may be referenced anywhere internal or external
60     /// to the visible references in the IR.
61     Public,
62 
63     /// The symbol is private and may only be referenced by SymbolRefAttrs local
64     /// to the operations within the current symbol table.
65     Private,
66 
67     /// The symbol is visible to the current IR, which may include operations in
68     /// symbol tables above the one that owns the current symbol. `Nested`
69     /// visibility allows for referencing a symbol outside of its current symbol
70     /// table, while retaining the ability to observe all uses.
71     Nested,
72   };
73 
74   /// Returns true if the given operation defines a symbol.
75   static bool isSymbol(Operation *op);
76 
77   /// Returns the name of the given symbol operation.
78   static StringRef getSymbolName(Operation *symbol);
79   /// Sets the name of the given symbol operation.
80   static void setSymbolName(Operation *symbol, StringRef name);
81 
82   /// Returns the visibility of the given symbol operation.
83   static Visibility getSymbolVisibility(Operation *symbol);
84   /// Sets the visibility of the given symbol operation.
85   static void setSymbolVisibility(Operation *symbol, Visibility vis);
86 
87   /// Returns the operation registered with the given symbol name with the
88   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
89   /// with the 'OpTrait::SymbolTable' trait.
90   static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
91   static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
92 
93   /// Returns the operation registered with the given symbol name within the
94   /// closest parent operation of, or including, 'from' with the
95   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
96   /// found.
97   static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
98   static Operation *lookupNearestSymbolFrom(Operation *from,
99                                             SymbolRefAttr symbol);
100 
101   /// This class represents a specific symbol use.
102   class SymbolUse {
103   public:
SymbolUse(Operation * op,SymbolRefAttr symbolRef)104     SymbolUse(Operation *op, SymbolRefAttr symbolRef)
105         : owner(op), symbolRef(symbolRef) {}
106 
107     /// Return the operation user of this symbol reference.
getUser()108     Operation *getUser() const { return owner; }
109 
110     /// Return the symbol reference that this use represents.
getSymbolRef()111     SymbolRefAttr getSymbolRef() const { return symbolRef; }
112 
113   private:
114     /// The operation that this access is held by.
115     Operation *owner;
116 
117     /// The symbol reference that this use represents.
118     SymbolRefAttr symbolRef;
119   };
120 
121   /// This class implements a range of SymbolRef uses.
122   class UseRange {
123   public:
UseRange(std::vector<SymbolUse> && uses)124     UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
125 
126     using iterator = std::vector<SymbolUse>::const_iterator;
begin()127     iterator begin() const { return uses.begin(); }
end()128     iterator end() const { return uses.end(); }
129 
130   private:
131     std::vector<SymbolUse> uses;
132   };
133 
134   /// Get an iterator range for all of the uses, for any symbol, that are nested
135   /// within the given operation 'from'. This does not traverse into any nested
136   /// symbol tables, and will also only return uses on 'from' if it does not
137   /// also define a symbol table. This is because we treat the region as the
138   /// boundary of the symbol table, and not the op itself. This function returns
139   /// None if there are any unknown operations that may potentially be symbol
140   /// tables.
141   static Optional<UseRange> getSymbolUses(Operation *from);
142 
143   /// Get all of the uses of the given symbol that are nested within the given
144   /// operation 'from'. This does not traverse into any nested symbol tables,
145   /// and will also only return uses on 'from' if it does not also define a
146   /// symbol table. This is because we treat the region as the boundary of the
147   /// symbol table, and not the op itself. This function returns None if there
148   /// are any unknown operations that may potentially be symbol tables.
149   static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
150   static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
151 
152   /// Return if the given symbol is known to have no uses that are nested
153   /// within the given operation 'from'. This does not traverse into any nested
154   /// symbol tables, and will also only count uses on 'from' if it does not also
155   /// define a symbol table. This is because we treat the region as the boundary
156   /// of the symbol table, and not the op itself. This function will also return
157   /// false if there are any unknown operations that may potentially be symbol
158   /// tables. This doesn't necessarily mean that there are no uses, we just
159   /// can't conservatively prove it.
160   static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
161   static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
162 
163   /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
164   /// provided symbol 'newSymbol' that are nested within the given operation
165   /// 'from'. This does not traverse into any nested symbol tables, and will
166   /// also only replace uses on 'from' if it does not also define a symbol
167   /// table. This is because we treat the region as the boundary of the symbol
168   /// table, and not the op itself. If there are any unknown operations that may
169   /// potentially be symbol tables, no uses are replaced and failure is
170   /// returned.
171   LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
172                                                            StringRef newSymbol,
173                                                            Operation *from);
174   LLVM_NODISCARD static LogicalResult
175   replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
176                        Operation *from);
177 
178 private:
179   Operation *symbolTableOp;
180 
181   /// This is a mapping from a name to the symbol with that name.
182   llvm::StringMap<Operation *> symbolTable;
183 
184   /// This is used when name conflicts are detected.
185   unsigned uniquingCounter = 0;
186 };
187 
//===----------------------------------------------------------------------===//
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//
191 
192 namespace OpTrait {
193 namespace impl {
194 LogicalResult verifySymbolTable(Operation *op);
195 LogicalResult verifySymbol(Operation *op);
196 } // namespace impl
197 
198 /// A trait used to provide symbol table functionalities to a region operation.
199 /// This operation must hold exactly 1 region. Once attached, all operations
200 /// that are directly within the region, i.e not including those within child
201 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
202 /// be verified to ensure that the names are uniqued. These operations must also
203 /// adhere to the constraints defined by the `Symbol` trait, even if they do not
204 /// inherit from it.
205 template <typename ConcreteType>
206 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
207 public:
verifyTrait(Operation * op)208   static LogicalResult verifyTrait(Operation *op) {
209     return impl::verifySymbolTable(op);
210   }
211 
212   /// Look up a symbol with the specified name, returning null if no such
213   /// name exists. Symbol names never include the @ on them. Note: This
214   /// performs a linear scan of held symbols.
lookupSymbol(StringRef name)215   Operation *lookupSymbol(StringRef name) {
216     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
217   }
lookupSymbol(StringRef name)218   template <typename T> T lookupSymbol(StringRef name) {
219     return dyn_cast_or_null<T>(lookupSymbol(name));
220   }
221 };
222 
223 /// A trait used to define a symbol that can be used on operations within a
224 /// symbol table. Operations using this trait must adhere to the following:
225 ///   * Have a StringAttr attribute named 'SymbolTable::getSymbolAttrName()'.
226 template <typename ConcreteType>
227 class Symbol : public TraitBase<ConcreteType, Symbol> {
228 public:
229   using Visibility = mlir::SymbolTable::Visibility;
230 
verifyTrait(Operation * op)231   static LogicalResult verifyTrait(Operation *op) {
232     return impl::verifySymbol(op);
233   }
234 
235   /// Returns the name of this symbol.
getName()236   StringRef getName() {
237     return this->getOperation()
238         ->template getAttrOfType<StringAttr>(
239             mlir::SymbolTable::getSymbolAttrName())
240         .getValue();
241   }
242 
243   /// Set the name of this symbol.
setName(StringRef name)244   void setName(StringRef name) {
245     this->getOperation()->setAttr(
246         mlir::SymbolTable::getSymbolAttrName(),
247         StringAttr::get(name, this->getOperation()->getContext()));
248   }
249 
250   /// Returns the visibility of the current symbol.
getVisibility()251   Visibility getVisibility() {
252     return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
253   }
254 
255   /// Sets the visibility of the current symbol.
setVisibility(Visibility vis)256   void setVisibility(Visibility vis) {
257     mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
258   }
259 
260   /// Get all of the uses of the current symbol that are nested within the given
261   /// operation 'from'.
262   /// Note: See mlir::SymbolTable::getSymbolUses for more details.
getSymbolUses(Operation * from)263   Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) {
264     return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
265   }
266 
267   /// Return if the current symbol is known to have no uses that are nested
268   /// within the given operation 'from'.
269   /// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
symbolKnownUseEmpty(Operation * from)270   bool symbolKnownUseEmpty(Operation *from) {
271     return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from);
272   }
273 
274   /// Attempt to replace all uses of the current symbol with the provided symbol
275   /// 'newSymbol' that are nested within the given operation 'from'.
276   /// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
replaceAllSymbolUses(StringRef newSymbol,Operation * from)277   LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol,
278                                                     Operation *from) {
279     return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
280                                                      newSymbol, from);
281   }
282 };
283 
284 } // end namespace OpTrait
285 } // end namespace mlir
286 
287 #endif // MLIR_IR_SYMBOLTABLE_H
288