1 //===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_SYMBOLTABLE_H
10 #define MLIR_IR_SYMBOLTABLE_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/StringMap.h"
16 
17 namespace mlir {
18 class Identifier;
19 class Operation;
20 
21 /// This class allows for representing and managing the symbol table used by
22 /// operations with the 'SymbolTable' trait. Inserting into and erasing from
23 /// this SymbolTable will also insert and erase from the Operation given to it
24 /// at construction.
25 class SymbolTable {
26 public:
27   /// Build a symbol table with the symbols within the given operation.
28   SymbolTable(Operation *symbolTableOp);
29 
30   /// Look up a symbol with the specified name, returning null if no such
31   /// name exists. Names never include the @ on them.
32   Operation *lookup(StringRef name) const;
lookup(StringRef name)33   template <typename T> T lookup(StringRef name) const {
34     return dyn_cast_or_null<T>(lookup(name));
35   }
36 
37   /// Erase the given symbol from the table.
38   void erase(Operation *symbol);
39 
40   /// Insert a new symbol into the table, and rename it as necessary to avoid
41   /// collisions. Also insert at the specified location in the body of the
42   /// associated operation if it is not already there. It is asserted that the
43   /// symbol is not inside another operation.
44   void insert(Operation *symbol, Block::iterator insertPt = {});
45 
46   /// Return the name of the attribute used for symbol names.
getSymbolAttrName()47   static StringRef getSymbolAttrName() { return "sym_name"; }
48 
49   /// Returns the associated operation.
getOp()50   Operation *getOp() const { return symbolTableOp; }
51 
52   /// Return the name of the attribute used for symbol visibility.
getVisibilityAttrName()53   static StringRef getVisibilityAttrName() { return "sym_visibility"; }
54 
55   //===--------------------------------------------------------------------===//
56   // Symbol Utilities
57   //===--------------------------------------------------------------------===//
58 
59   /// An enumeration detailing the different visibility types that a symbol may
60   /// have.
61   enum class Visibility {
62     /// The symbol is public and may be referenced anywhere internal or external
63     /// to the visible references in the IR.
64     Public,
65 
66     /// The symbol is private and may only be referenced by SymbolRefAttrs local
67     /// to the operations within the current symbol table.
68     Private,
69 
70     /// The symbol is visible to the current IR, which may include operations in
71     /// symbol tables above the one that owns the current symbol. `Nested`
72     /// visibility allows for referencing a symbol outside of its current symbol
73     /// table, while retaining the ability to observe all uses.
74     Nested,
75   };
76 
77   /// Returns the name of the given symbol operation.
78   static StringRef getSymbolName(Operation *symbol);
79   /// Sets the name of the given symbol operation.
80   static void setSymbolName(Operation *symbol, StringRef name);
81 
82   /// Returns the visibility of the given symbol operation.
83   static Visibility getSymbolVisibility(Operation *symbol);
84   /// Sets the visibility of the given symbol operation.
85   static void setSymbolVisibility(Operation *symbol, Visibility vis);
86 
87   /// Returns the nearest symbol table from a given operation `from`. Returns
88   /// nullptr if no valid parent symbol table could be found.
89   static Operation *getNearestSymbolTable(Operation *from);
90 
91   /// Walks all symbol table operations nested within, and including, `op`. For
92   /// each symbol table operation, the provided callback is invoked with the op
93   /// and a boolean signifying if the symbols within that symbol table can be
94   /// treated as if all uses within the IR are visible to the caller.
95   /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
96   /// within `op` are visible.
97   static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
98                                function_ref<void(Operation *, bool)> callback);
99 
100   /// Returns the operation registered with the given symbol name with the
101   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
102   /// with the 'OpTrait::SymbolTable' trait.
103   static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
104   static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
105   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
106   /// by a given SymbolRefAttr. Returns failure if any of the nested references
107   /// could not be resolved.
108   static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
109                                       SmallVectorImpl<Operation *> &symbols);
110 
111   /// Returns the operation registered with the given symbol name within the
112   /// closest parent operation of, or including, 'from' with the
113   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
114   /// found.
115   static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
116   static Operation *lookupNearestSymbolFrom(Operation *from,
117                                             SymbolRefAttr symbol);
118   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringRef symbol)119   static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
120     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
121   }
122   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)123   static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
124     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
125   }
126 
127   /// This class represents a specific symbol use.
128   class SymbolUse {
129   public:
SymbolUse(Operation * op,SymbolRefAttr symbolRef)130     SymbolUse(Operation *op, SymbolRefAttr symbolRef)
131         : owner(op), symbolRef(symbolRef) {}
132 
133     /// Return the operation user of this symbol reference.
getUser()134     Operation *getUser() const { return owner; }
135 
136     /// Return the symbol reference that this use represents.
getSymbolRef()137     SymbolRefAttr getSymbolRef() const { return symbolRef; }
138 
139   private:
140     /// The operation that this access is held by.
141     Operation *owner;
142 
143     /// The symbol reference that this use represents.
144     SymbolRefAttr symbolRef;
145   };
146 
147   /// This class implements a range of SymbolRef uses.
148   class UseRange {
149   public:
UseRange(std::vector<SymbolUse> && uses)150     UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
151 
152     using iterator = std::vector<SymbolUse>::const_iterator;
begin()153     iterator begin() const { return uses.begin(); }
end()154     iterator end() const { return uses.end(); }
empty()155     bool empty() const { return uses.empty(); }
156 
157   private:
158     std::vector<SymbolUse> uses;
159   };
160 
161   /// Get an iterator range for all of the uses, for any symbol, that are nested
162   /// within the given operation 'from'. This does not traverse into any nested
163   /// symbol tables. This function returns None if there are any unknown
164   /// operations that may potentially be symbol tables.
165   static Optional<UseRange> getSymbolUses(Operation *from);
166   static Optional<UseRange> getSymbolUses(Region *from);
167 
168   /// Get all of the uses of the given symbol that are nested within the given
169   /// operation 'from'. This does not traverse into any nested symbol tables.
170   /// This function returns None if there are any unknown operations that may
171   /// potentially be symbol tables.
172   static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
173   static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
174   static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from);
175   static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
176 
177   /// Return if the given symbol is known to have no uses that are nested
178   /// within the given operation 'from'. This does not traverse into any nested
179   /// symbol tables. This function will also return false if there are any
180   /// unknown operations that may potentially be symbol tables. This doesn't
181   /// necessarily mean that there are no uses, we just can't conservatively
182   /// prove it.
183   static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
184   static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
185   static bool symbolKnownUseEmpty(StringRef symbol, Region *from);
186   static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
187 
188   /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
189   /// provided symbol 'newSymbol' that are nested within the given operation
190   /// 'from'. This does not traverse into any nested symbol tables. If there are
191   /// any unknown operations that may potentially be symbol tables, no uses are
192   /// replaced and failure is returned.
193   static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
194                                             StringRef newSymbol,
195                                             Operation *from);
196   static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
197                                             StringRef newSymbolName,
198                                             Operation *from);
199   static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
200                                             StringRef newSymbol, Region *from);
201   static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
202                                             StringRef newSymbolName,
203                                             Region *from);
204 
205 private:
206   Operation *symbolTableOp;
207 
208   /// This is a mapping from a name to the symbol with that name.
209   llvm::StringMap<Operation *> symbolTable;
210 
211   /// This is used when name conflicts are detected.
212   unsigned uniquingCounter = 0;
213 };
214 
215 raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility);
216 
217 //===----------------------------------------------------------------------===//
218 // SymbolTableCollection
219 //===----------------------------------------------------------------------===//
220 
221 /// This class represents a collection of `SymbolTable`s. This simplifies
222 /// certain algorithms that run recursively on nested symbol tables. Symbol
223 /// tables are constructed lazily to reduce the upfront cost of constructing
224 /// unnecessary tables.
225 class SymbolTableCollection {
226 public:
227   /// Look up a symbol with the specified name within the specified symbol table
228   /// operation, returning null if no such name exists.
229   Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
230   Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
231   template <typename T, typename NameT>
lookupSymbolIn(Operation * symbolTableOp,NameT && name)232   T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const {
233     return dyn_cast_or_null<T>(
234         lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
235   }
236   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
237   /// by a given SymbolRefAttr when resolved within the provided symbol table
238   /// operation. Returns failure if any of the nested references could not be
239   /// resolved.
240   LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
241                                SmallVectorImpl<Operation *> &symbols);
242 
243   /// Returns the operation registered with the given symbol name within the
244   /// closest parent operation of, or including, 'from' with the
245   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
246   /// found.
247   Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
248   Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
249   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringRef symbol)250   T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
251     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
252   }
253   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)254   T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
255     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
256   }
257 
258   /// Lookup, or create, a symbol table for an operation.
259   SymbolTable &getSymbolTable(Operation *op);
260 
261 private:
262   /// The constructed symbol tables nested within this table.
263   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
264 };
265 
266 //===----------------------------------------------------------------------===//
267 // SymbolUserMap
268 //===----------------------------------------------------------------------===//
269 
270 /// This class represents a map of symbols to users, and provides efficient
271 /// implementations of symbol queries related to users; such as collecting the
272 /// users of a symbol, replacing all uses, etc.
273 class SymbolUserMap {
274 public:
275   /// Build a user map for all of the symbols defined in regions nested under
276   /// 'symbolTableOp'. A reference to the provided symbol table collection is
277   /// kept by the user map to ensure efficient lookups, thus the lifetime should
278   /// extend beyond that of this map.
279   SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp);
280 
281   /// Return the users of the provided symbol operation.
getUsers(Operation * symbol)282   ArrayRef<Operation *> getUsers(Operation *symbol) const {
283     auto it = symbolToUsers.find(symbol);
284     return it != symbolToUsers.end() ? it->second.getArrayRef() : llvm::None;
285   }
286 
287   /// Return true if the given symbol has no uses.
use_empty(Operation * symbol)288   bool use_empty(Operation *symbol) const {
289     return !symbolToUsers.count(symbol);
290   }
291 
292   /// Replace all of the uses of the given symbol with `newSymbolName`.
293   void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName);
294 
295 private:
296   /// A reference to the symbol table used to construct this map.
297   SymbolTableCollection &symbolTable;
298 
299   /// A map of symbol operations to symbol users.
300   DenseMap<Operation *, SetVector<Operation *>> symbolToUsers;
301 };
302 
303 //===----------------------------------------------------------------------===//
304 // SymbolTable Trait Types
305 //===----------------------------------------------------------------------===//
306 
307 namespace detail {
308 LogicalResult verifySymbolTable(Operation *op);
309 LogicalResult verifySymbol(Operation *op);
310 } // namespace detail
311 
312 namespace OpTrait {
313 /// A trait used to provide symbol table functionalities to a region operation.
314 /// This operation must hold exactly 1 region. Once attached, all operations
315 /// that are directly within the region, i.e not including those within child
316 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
317 /// be verified to ensure that the names are uniqued. These operations must also
318 /// adhere to the constraints defined by the `Symbol` trait, even if they do not
319 /// inherit from it.
320 template <typename ConcreteType>
321 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
322 public:
verifyTrait(Operation * op)323   static LogicalResult verifyTrait(Operation *op) {
324     return ::mlir::detail::verifySymbolTable(op);
325   }
326 
327   /// Look up a symbol with the specified name, returning null if no such
328   /// name exists. Symbol names never include the @ on them. Note: This
329   /// performs a linear scan of held symbols.
lookupSymbol(StringRef name)330   Operation *lookupSymbol(StringRef name) {
331     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
332   }
lookupSymbol(StringRef name)333   template <typename T> T lookupSymbol(StringRef name) {
334     return dyn_cast_or_null<T>(lookupSymbol(name));
335   }
lookupSymbol(SymbolRefAttr symbol)336   Operation *lookupSymbol(SymbolRefAttr symbol) {
337     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
338   }
lookupSymbol(SymbolRefAttr symbol)339   template <typename T> T lookupSymbol(SymbolRefAttr symbol) {
340     return dyn_cast_or_null<T>(lookupSymbol(symbol));
341   }
342 };
343 
344 } // end namespace OpTrait
345 
346 //===----------------------------------------------------------------------===//
347 // Visibility parsing implementation.
348 //===----------------------------------------------------------------------===//
349 
350 namespace impl {
351 /// Parse an optional visibility attribute keyword (i.e., public, private, or
352 /// nested) without quotes in a string attribute named 'attrName'.
353 ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
354                                            NamedAttrList &attrs);
355 } // end namespace impl
356 
357 } // end namespace mlir
358 
359 /// Include the generated symbol interfaces.
360 #include "mlir/IR/SymbolInterfaces.h.inc"
361 
362 #endif // MLIR_IR_SYMBOLTABLE_H
363