1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_SYMBOLTABLE_H
10 #define MLIR_IR_SYMBOLTABLE_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "llvm/ADT/StringMap.h"
15 
16 namespace mlir {
17 class Identifier;
18 class Operation;
19 
20 /// This class allows for representing and managing the symbol table used by
21 /// operations with the 'SymbolTable' trait. Inserting into and erasing from
22 /// this SymbolTable will also insert and erase from the Operation given to it
23 /// at construction.
24 class SymbolTable {
25 public:
26   /// Build a symbol table with the symbols within the given operation.
27   SymbolTable(Operation *symbolTableOp);
28 
29   /// Look up a symbol with the specified name, returning null if no such
30   /// name exists. Names never include the @ on them.
31   Operation *lookup(StringRef name) const;
lookup(StringRef name)32   template <typename T> T lookup(StringRef name) const {
33     return dyn_cast_or_null<T>(lookup(name));
34   }
35 
36   /// Erase the given symbol from the table.
37   void erase(Operation *symbol);
38 
39   /// Insert a new symbol into the table, and rename it as necessary to avoid
40   /// collisions. Also insert at the specified location in the body of the
41   /// associated operation.
42   void insert(Operation *symbol, Block::iterator insertPt = {});
43 
44   /// Return the name of the attribute used for symbol names.
getSymbolAttrName()45   static StringRef getSymbolAttrName() { return "sym_name"; }
46 
47   /// Returns the associated operation.
getOp()48   Operation *getOp() const { return symbolTableOp; }
49 
50   /// Return the name of the attribute used for symbol visibility.
getVisibilityAttrName()51   static StringRef getVisibilityAttrName() { return "sym_visibility"; }
52 
53   //===--------------------------------------------------------------------===//
54   // Symbol Utilities
55   //===--------------------------------------------------------------------===//
56 
57   /// An enumeration detailing the different visibility types that a symbol may
58   /// have.
59   enum class Visibility {
60     /// The symbol is public and may be referenced anywhere internal or external
61     /// to the visible references in the IR.
62     Public,
63 
64     /// The symbol is private and may only be referenced by SymbolRefAttrs local
65     /// to the operations within the current symbol table.
66     Private,
67 
68     /// The symbol is visible to the current IR, which may include operations in
69     /// symbol tables above the one that owns the current symbol. `Nested`
70     /// visibility allows for referencing a symbol outside of its current symbol
71     /// table, while retaining the ability to observe all uses.
72     Nested,
73   };
74 
75   /// Returns the name of the given symbol operation.
76   static StringRef getSymbolName(Operation *symbol);
77   /// Sets the name of the given symbol operation.
78   static void setSymbolName(Operation *symbol, StringRef name);
79 
80   /// Returns the visibility of the given symbol operation.
81   static Visibility getSymbolVisibility(Operation *symbol);
82   /// Sets the visibility of the given symbol operation.
83   static void setSymbolVisibility(Operation *symbol, Visibility vis);
84 
85   /// Returns the nearest symbol table from a given operation `from`. Returns
86   /// nullptr if no valid parent symbol table could be found.
87   static Operation *getNearestSymbolTable(Operation *from);
88 
89   /// Walks all symbol table operations nested within, and including, `op`. For
90   /// each symbol table operation, the provided callback is invoked with the op
91   /// and a boolean signifying if the symbols within that symbol table can be
92   /// treated as if all uses within the IR are visible to the caller.
93   /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
94   /// within `op` are visible.
95   static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
96                                function_ref<void(Operation *, bool)> callback);
97 
98   /// Returns the operation registered with the given symbol name with the
99   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
100   /// with the 'OpTrait::SymbolTable' trait.
101   static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
102   static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
103   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
104   /// by a given SymbolRefAttr. Returns failure if any of the nested references
105   /// could not be resolved.
106   static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
107                                       SmallVectorImpl<Operation *> &symbols);
108 
109   /// Returns the operation registered with the given symbol name within the
110   /// closest parent operation of, or including, 'from' with the
111   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
112   /// found.
113   static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
114   static Operation *lookupNearestSymbolFrom(Operation *from,
115                                             SymbolRefAttr symbol);
116   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringRef symbol)117   static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
118     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
119   }
120   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)121   static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
122     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
123   }
124 
125   /// This class represents a specific symbol use.
126   class SymbolUse {
127   public:
SymbolUse(Operation * op,SymbolRefAttr symbolRef)128     SymbolUse(Operation *op, SymbolRefAttr symbolRef)
129         : owner(op), symbolRef(symbolRef) {}
130 
131     /// Return the operation user of this symbol reference.
getUser()132     Operation *getUser() const { return owner; }
133 
134     /// Return the symbol reference that this use represents.
getSymbolRef()135     SymbolRefAttr getSymbolRef() const { return symbolRef; }
136 
137   private:
138     /// The operation that this access is held by.
139     Operation *owner;
140 
141     /// The symbol reference that this use represents.
142     SymbolRefAttr symbolRef;
143   };
144 
145   /// This class implements a range of SymbolRef uses.
146   class UseRange {
147   public:
UseRange(std::vector<SymbolUse> && uses)148     UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
149 
150     using iterator = std::vector<SymbolUse>::const_iterator;
begin()151     iterator begin() const { return uses.begin(); }
end()152     iterator end() const { return uses.end(); }
empty()153     bool empty() const { return uses.empty(); }
154 
155   private:
156     std::vector<SymbolUse> uses;
157   };
158 
159   /// Get an iterator range for all of the uses, for any symbol, that are nested
160   /// within the given operation 'from'. This does not traverse into any nested
161   /// symbol tables. This function returns None if there are any unknown
162   /// operations that may potentially be symbol tables.
163   static Optional<UseRange> getSymbolUses(Operation *from);
164   static Optional<UseRange> getSymbolUses(Region *from);
165 
166   /// Get all of the uses of the given symbol that are nested within the given
167   /// operation 'from'. This does not traverse into any nested symbol tables.
168   /// This function returns None if there are any unknown operations that may
169   /// potentially be symbol tables.
170   static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
171   static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
172   static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from);
173   static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
174 
175   /// Return if the given symbol is known to have no uses that are nested
176   /// within the given operation 'from'. This does not traverse into any nested
177   /// symbol tables. This function will also return false if there are any
178   /// unknown operations that may potentially be symbol tables. This doesn't
179   /// necessarily mean that there are no uses, we just can't conservatively
180   /// prove it.
181   static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
182   static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
183   static bool symbolKnownUseEmpty(StringRef symbol, Region *from);
184   static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
185 
186   /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
187   /// provided symbol 'newSymbol' that are nested within the given operation
188   /// 'from'. This does not traverse into any nested symbol tables. If there are
189   /// any unknown operations that may potentially be symbol tables, no uses are
190   /// replaced and failure is returned.
191   LLVM_NODISCARD static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
192                                                            StringRef newSymbol,
193                                                            Operation *from);
194   LLVM_NODISCARD static LogicalResult
195   replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
196                        Operation *from);
197   LLVM_NODISCARD static LogicalResult
198   replaceAllSymbolUses(StringRef oldSymbol, StringRef newSymbol, Region *from);
199   LLVM_NODISCARD static LogicalResult
200   replaceAllSymbolUses(Operation *oldSymbol, StringRef newSymbolName,
201                        Region *from);
202 
203 private:
204   Operation *symbolTableOp;
205 
206   /// This is a mapping from a name to the symbol with that name.
207   llvm::StringMap<Operation *> symbolTable;
208 
209   /// This is used when name conflicts are detected.
210   unsigned uniquingCounter = 0;
211 };
212 
213 //===----------------------------------------------------------------------===//
214 // SymbolTableCollection
215 //===----------------------------------------------------------------------===//
216 
217 /// This class represents a collection of `SymbolTable`s. This simplifies
218 /// certain algorithms that run recursively on nested symbol tables. Symbol
219 /// tables are constructed lazily to reduce the upfront cost of constructing
220 /// unnecessary tables.
221 class SymbolTableCollection {
222 public:
223   /// Look up a symbol with the specified name within the specified symbol table
224   /// operation, returning null if no such name exists.
225   Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
226   Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
227   template <typename T, typename NameT>
lookupSymbolIn(Operation * symbolTableOp,NameT && name)228   T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const {
229     return dyn_cast_or_null<T>(
230         lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
231   }
232   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
233   /// by a given SymbolRefAttr when resolved within the provided symbol table
234   /// operation. Returns failure if any of the nested references could not be
235   /// resolved.
236   LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
237                                SmallVectorImpl<Operation *> &symbols);
238 
239   /// Returns the operation registered with the given symbol name within the
240   /// closest parent operation of, or including, 'from' with the
241   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
242   /// found.
243   Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
244   Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
245   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringRef symbol)246   T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
247     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
248   }
249   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)250   T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
251     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
252   }
253 
254   /// Lookup, or create, a symbol table for an operation.
255   SymbolTable &getSymbolTable(Operation *op);
256 
257 private:
258   /// The constructed symbol tables nested within this table.
259   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
260 };
261 
262 //===----------------------------------------------------------------------===//
263 // SymbolTable Trait Types
264 //===----------------------------------------------------------------------===//
265 
namespace detail {
/// Verification entry point for operations with the 'SymbolTable' trait;
/// `OpTrait::SymbolTable::verifyTrait` forwards `op` to this function and
/// returns its result.
LogicalResult verifySymbolTable(Operation *op);
/// NOTE(review): presumably the verification entry point for individual
/// symbol operations (the `Symbol` constraints). No caller is visible in this
/// header; confirm against the definition.
LogicalResult verifySymbol(Operation *op);
} // namespace detail
270 
271 namespace OpTrait {
272 /// A trait used to provide symbol table functionalities to a region operation.
273 /// This operation must hold exactly 1 region. Once attached, all operations
274 /// that are directly within the region, i.e not including those within child
275 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
276 /// be verified to ensure that the names are uniqued. These operations must also
277 /// adhere to the constraints defined by the `Symbol` trait, even if they do not
278 /// inherit from it.
279 template <typename ConcreteType>
280 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
281 public:
verifyTrait(Operation * op)282   static LogicalResult verifyTrait(Operation *op) {
283     return ::mlir::detail::verifySymbolTable(op);
284   }
285 
286   /// Look up a symbol with the specified name, returning null if no such
287   /// name exists. Symbol names never include the @ on them. Note: This
288   /// performs a linear scan of held symbols.
lookupSymbol(StringRef name)289   Operation *lookupSymbol(StringRef name) {
290     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
291   }
lookupSymbol(StringRef name)292   template <typename T> T lookupSymbol(StringRef name) {
293     return dyn_cast_or_null<T>(lookupSymbol(name));
294   }
lookupSymbol(SymbolRefAttr symbol)295   Operation *lookupSymbol(SymbolRefAttr symbol) {
296     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
297   }
298   template <typename T>
lookupSymbol(SymbolRefAttr symbol)299   T lookupSymbol(SymbolRefAttr symbol) {
300     return dyn_cast_or_null<T>(lookupSymbol(symbol));
301   }
302 };
303 
304 } // end namespace OpTrait
305 
306 //===----------------------------------------------------------------------===//
307 // Visibility parsing implementation.
308 //===----------------------------------------------------------------------===//
309 
namespace impl {
/// Parse an optional visibility attribute keyword (i.e., public, private, or
/// nested) without quotes.
/// NOTE(review): the signature takes no attribute-name parameter; presumably
/// a parsed keyword is recorded in `attrs` as a string attribute named
/// `SymbolTable::getVisibilityAttrName()`. Confirm against the definition.
ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
                                           NamedAttrList &attrs);
} // end namespace impl
316 
317 } // end namespace mlir
318 
319 /// Include the generated symbol interfaces.
320 #include "mlir/IR/SymbolInterfaces.h.inc"
321 
322 #endif // MLIR_IR_SYMBOLTABLE_H
323