//===- SymbolTable.h - MLIR Symbol Table Class ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

9 #ifndef MLIR_IR_SYMBOLTABLE_H
10 #define MLIR_IR_SYMBOLTABLE_H
11 
12 #include "mlir/IR/Attributes.h"
13 #include "mlir/IR/OpDefinition.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/StringMap.h"
16 
17 namespace mlir {
18 class Identifier;
19 class Operation;
20 
21 /// This class allows for representing and managing the symbol table used by
22 /// operations with the 'SymbolTable' trait. Inserting into and erasing from
23 /// this SymbolTable will also insert and erase from the Operation given to it
24 /// at construction.
25 class SymbolTable {
26 public:
27   /// Build a symbol table with the symbols within the given operation.
28   SymbolTable(Operation *symbolTableOp);
29 
30   /// Look up a symbol with the specified name, returning null if no such
31   /// name exists. Names never include the @ on them.
32   Operation *lookup(StringRef name) const;
33   template <typename T>
lookup(StringRef name)34   T lookup(StringRef name) const {
35     return dyn_cast_or_null<T>(lookup(name));
36   }
37 
38   /// Look up a symbol with the specified name, returning null if no such
39   /// name exists. Names never include the @ on them.
40   Operation *lookup(StringAttr name) const;
41   template <typename T>
lookup(StringAttr name)42   T lookup(StringAttr name) const {
43     return dyn_cast_or_null<T>(lookup(name));
44   }
45 
46   /// Erase the given symbol from the table.
47   void erase(Operation *symbol);
48 
49   /// Insert a new symbol into the table, and rename it as necessary to avoid
50   /// collisions. Also insert at the specified location in the body of the
51   /// associated operation if it is not already there. It is asserted that the
52   /// symbol is not inside another operation.
53   void insert(Operation *symbol, Block::iterator insertPt = {});
54 
55   /// Return the name of the attribute used for symbol names.
getSymbolAttrName()56   static StringRef getSymbolAttrName() { return "sym_name"; }
57 
58   /// Returns the associated operation.
getOp()59   Operation *getOp() const { return symbolTableOp; }
60 
61   /// Return the name of the attribute used for symbol visibility.
getVisibilityAttrName()62   static StringRef getVisibilityAttrName() { return "sym_visibility"; }
63 
64   //===--------------------------------------------------------------------===//
65   // Symbol Utilities
66   //===--------------------------------------------------------------------===//
67 
68   /// An enumeration detailing the different visibility types that a symbol may
69   /// have.
70   enum class Visibility {
71     /// The symbol is public and may be referenced anywhere internal or external
72     /// to the visible references in the IR.
73     Public,
74 
75     /// The symbol is private and may only be referenced by SymbolRefAttrs local
76     /// to the operations within the current symbol table.
77     Private,
78 
79     /// The symbol is visible to the current IR, which may include operations in
80     /// symbol tables above the one that owns the current symbol. `Nested`
81     /// visibility allows for referencing a symbol outside of its current symbol
82     /// table, while retaining the ability to observe all uses.
83     Nested,
84   };
85 
86   /// Returns the name of the given symbol operation, aborting if no symbol is
87   /// present.
88   static StringAttr getSymbolName(Operation *symbol);
89 
90   /// Sets the name of the given symbol operation.
91   static void setSymbolName(Operation *symbol, StringAttr name);
setSymbolName(Operation * symbol,StringRef name)92   static void setSymbolName(Operation *symbol, StringRef name) {
93     setSymbolName(symbol, StringAttr::get(symbol->getContext(), name));
94   }
95 
96   /// Returns the visibility of the given symbol operation.
97   static Visibility getSymbolVisibility(Operation *symbol);
98   /// Sets the visibility of the given symbol operation.
99   static void setSymbolVisibility(Operation *symbol, Visibility vis);
100 
101   /// Returns the nearest symbol table from a given operation `from`. Returns
102   /// nullptr if no valid parent symbol table could be found.
103   static Operation *getNearestSymbolTable(Operation *from);
104 
105   /// Walks all symbol table operations nested within, and including, `op`. For
106   /// each symbol table operation, the provided callback is invoked with the op
107   /// and a boolean signifying if the symbols within that symbol table can be
108   /// treated as if all uses within the IR are visible to the caller.
109   /// `allSymUsesVisible` identifies whether all of the symbol uses of symbols
110   /// within `op` are visible.
111   static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
112                                function_ref<void(Operation *, bool)> callback);
113 
114   /// Returns the operation registered with the given symbol name with the
115   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
116   /// with the 'OpTrait::SymbolTable' trait.
117   static Operation *lookupSymbolIn(Operation *op, StringAttr symbol);
lookupSymbolIn(Operation * op,StringRef symbol)118   static Operation *lookupSymbolIn(Operation *op, StringRef symbol) {
119     return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol));
120   }
121   static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
122   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
123   /// by a given SymbolRefAttr. Returns failure if any of the nested references
124   /// could not be resolved.
125   static LogicalResult lookupSymbolIn(Operation *op, SymbolRefAttr symbol,
126                                       SmallVectorImpl<Operation *> &symbols);
127 
128   /// Returns the operation registered with the given symbol name within the
129   /// closest parent operation of, or including, 'from' with the
130   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
131   /// found.
132   static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
133   static Operation *lookupNearestSymbolFrom(Operation *from,
134                                             SymbolRefAttr symbol);
135   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringAttr symbol)136   static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
137     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
138   }
139   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)140   static T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
141     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
142   }
143 
144   /// This class represents a specific symbol use.
145   class SymbolUse {
146   public:
SymbolUse(Operation * op,SymbolRefAttr symbolRef)147     SymbolUse(Operation *op, SymbolRefAttr symbolRef)
148         : owner(op), symbolRef(symbolRef) {}
149 
150     /// Return the operation user of this symbol reference.
getUser()151     Operation *getUser() const { return owner; }
152 
153     /// Return the symbol reference that this use represents.
getSymbolRef()154     SymbolRefAttr getSymbolRef() const { return symbolRef; }
155 
156   private:
157     /// The operation that this access is held by.
158     Operation *owner;
159 
160     /// The symbol reference that this use represents.
161     SymbolRefAttr symbolRef;
162   };
163 
164   /// This class implements a range of SymbolRef uses.
165   class UseRange {
166   public:
UseRange(std::vector<SymbolUse> && uses)167     UseRange(std::vector<SymbolUse> &&uses) : uses(std::move(uses)) {}
168 
169     using iterator = std::vector<SymbolUse>::const_iterator;
begin()170     iterator begin() const { return uses.begin(); }
end()171     iterator end() const { return uses.end(); }
empty()172     bool empty() const { return uses.empty(); }
173 
174   private:
175     std::vector<SymbolUse> uses;
176   };
177 
178   /// Get an iterator range for all of the uses, for any symbol, that are nested
179   /// within the given operation 'from'. This does not traverse into any nested
180   /// symbol tables. This function returns None if there are any unknown
181   /// operations that may potentially be symbol tables.
182   static Optional<UseRange> getSymbolUses(Operation *from);
183   static Optional<UseRange> getSymbolUses(Region *from);
184 
185   /// Get all of the uses of the given symbol that are nested within the given
186   /// operation 'from'. This does not traverse into any nested symbol tables.
187   /// This function returns None if there are any unknown operations that may
188   /// potentially be symbol tables.
189   static Optional<UseRange> getSymbolUses(StringAttr symbol, Operation *from);
190   static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
191   static Optional<UseRange> getSymbolUses(StringAttr symbol, Region *from);
192   static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
193 
194   /// Return if the given symbol is known to have no uses that are nested
195   /// within the given operation 'from'. This does not traverse into any nested
196   /// symbol tables. This function will also return false if there are any
197   /// unknown operations that may potentially be symbol tables. This doesn't
198   /// necessarily mean that there are no uses, we just can't conservatively
199   /// prove it.
200   static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from);
201   static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
202   static bool symbolKnownUseEmpty(StringAttr symbol, Region *from);
203   static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
204 
205   /// Attempt to replace all uses of the given symbol 'oldSymbol' with the
206   /// provided symbol 'newSymbol' that are nested within the given operation
207   /// 'from'. This does not traverse into any nested symbol tables. If there are
208   /// any unknown operations that may potentially be symbol tables, no uses are
209   /// replaced and failure is returned.
210   static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
211                                             StringAttr newSymbol,
212                                             Operation *from);
213   static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
214                                             StringAttr newSymbolName,
215                                             Operation *from);
216   static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
217                                             StringAttr newSymbol, Region *from);
218   static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
219                                             StringAttr newSymbolName,
220                                             Region *from);
221 
222 private:
223   Operation *symbolTableOp;
224 
225   /// This is a mapping from a name to the symbol with that name.  They key is
226   /// always known to be a StringAttr.
227   DenseMap<Attribute, Operation *> symbolTable;
228 
229   /// This is used when name conflicts are detected.
230   unsigned uniquingCounter = 0;
231 };
232 
233 raw_ostream &operator<<(raw_ostream &os, SymbolTable::Visibility visibility);
234 
//===----------------------------------------------------------------------===//
// SymbolTableCollection
//===----------------------------------------------------------------------===//

239 /// This class represents a collection of `SymbolTable`s. This simplifies
240 /// certain algorithms that run recursively on nested symbol tables. Symbol
241 /// tables are constructed lazily to reduce the upfront cost of constructing
242 /// unnecessary tables.
243 class SymbolTableCollection {
244 public:
245   /// Look up a symbol with the specified name within the specified symbol table
246   /// operation, returning null if no such name exists.
247   Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
248   Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
249   template <typename T, typename NameT>
lookupSymbolIn(Operation * symbolTableOp,NameT && name)250   T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const {
251     return dyn_cast_or_null<T>(
252         lookupSymbolIn(symbolTableOp, std::forward<NameT>(name)));
253   }
254   /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
255   /// by a given SymbolRefAttr when resolved within the provided symbol table
256   /// operation. Returns failure if any of the nested references could not be
257   /// resolved.
258   LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
259                                SmallVectorImpl<Operation *> &symbols);
260 
261   /// Returns the operation registered with the given symbol name within the
262   /// closest parent operation of, or including, 'from' with the
263   /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
264   /// found.
265   Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
266   Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
267   template <typename T>
lookupNearestSymbolFrom(Operation * from,StringAttr symbol)268   T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
269     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
270   }
271   template <typename T>
lookupNearestSymbolFrom(Operation * from,SymbolRefAttr symbol)272   T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
273     return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
274   }
275 
276   /// Lookup, or create, a symbol table for an operation.
277   SymbolTable &getSymbolTable(Operation *op);
278 
279 private:
280   /// The constructed symbol tables nested within this table.
281   DenseMap<Operation *, std::unique_ptr<SymbolTable>> symbolTables;
282 };
283 
//===----------------------------------------------------------------------===//
// SymbolUserMap
//===----------------------------------------------------------------------===//

288 /// This class represents a map of symbols to users, and provides efficient
289 /// implementations of symbol queries related to users; such as collecting the
290 /// users of a symbol, replacing all uses, etc.
291 class SymbolUserMap {
292 public:
293   /// Build a user map for all of the symbols defined in regions nested under
294   /// 'symbolTableOp'. A reference to the provided symbol table collection is
295   /// kept by the user map to ensure efficient lookups, thus the lifetime should
296   /// extend beyond that of this map.
297   SymbolUserMap(SymbolTableCollection &symbolTable, Operation *symbolTableOp);
298 
299   /// Return the users of the provided symbol operation.
getUsers(Operation * symbol)300   ArrayRef<Operation *> getUsers(Operation *symbol) const {
301     auto it = symbolToUsers.find(symbol);
302     return it != symbolToUsers.end() ? it->second.getArrayRef() : llvm::None;
303   }
304 
305   /// Return true if the given symbol has no uses.
use_empty(Operation * symbol)306   bool use_empty(Operation *symbol) const {
307     return !symbolToUsers.count(symbol);
308   }
309 
310   /// Replace all of the uses of the given symbol with `newSymbolName`.
311   void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName);
312 
313 private:
314   /// A reference to the symbol table used to construct this map.
315   SymbolTableCollection &symbolTable;
316 
317   /// A map of symbol operations to symbol users.
318   DenseMap<Operation *, SetVector<Operation *>> symbolToUsers;
319 };
320 
//===----------------------------------------------------------------------===//
// SymbolTable Trait Types
//===----------------------------------------------------------------------===//

325 namespace detail {
326 LogicalResult verifySymbolTable(Operation *op);
327 LogicalResult verifySymbol(Operation *op);
328 } // namespace detail
329 
330 namespace OpTrait {
331 /// A trait used to provide symbol table functionalities to a region operation.
332 /// This operation must hold exactly 1 region. Once attached, all operations
333 /// that are directly within the region, i.e not including those within child
334 /// regions, that contain a 'SymbolTable::getSymbolAttrName()' StringAttr will
335 /// be verified to ensure that the names are uniqued. These operations must also
336 /// adhere to the constraints defined by the `Symbol` trait, even if they do not
337 /// inherit from it.
338 template <typename ConcreteType>
339 class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
340 public:
verifyTrait(Operation * op)341   static LogicalResult verifyTrait(Operation *op) {
342     return ::mlir::detail::verifySymbolTable(op);
343   }
344 
345   /// Look up a symbol with the specified name, returning null if no such
346   /// name exists. Symbol names never include the @ on them. Note: This
347   /// performs a linear scan of held symbols.
lookupSymbol(StringAttr name)348   Operation *lookupSymbol(StringAttr name) {
349     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
350   }
351   template <typename T>
lookupSymbol(StringAttr name)352   T lookupSymbol(StringAttr name) {
353     return dyn_cast_or_null<T>(lookupSymbol(name));
354   }
lookupSymbol(SymbolRefAttr symbol)355   Operation *lookupSymbol(SymbolRefAttr symbol) {
356     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
357   }
358   template <typename T>
lookupSymbol(SymbolRefAttr symbol)359   T lookupSymbol(SymbolRefAttr symbol) {
360     return dyn_cast_or_null<T>(lookupSymbol(symbol));
361   }
362 
lookupSymbol(StringRef name)363   Operation *lookupSymbol(StringRef name) {
364     return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
365   }
366   template <typename T>
lookupSymbol(StringRef name)367   T lookupSymbol(StringRef name) {
368     return dyn_cast_or_null<T>(lookupSymbol(name));
369   }
370 };
371 
372 } // end namespace OpTrait
373 
//===----------------------------------------------------------------------===//
// Visibility parsing implementation.
//===----------------------------------------------------------------------===//

378 namespace impl {
379 /// Parse an optional visibility attribute keyword (i.e., public, private, or
380 /// nested) without quotes in a string attribute named 'attrName'.
381 ParseResult parseOptionalVisibilityKeyword(OpAsmParser &parser,
382                                            NamedAttrList &attrs);
383 } // end namespace impl
384 
385 } // end namespace mlir
386 
387 /// Include the generated symbol interfaces.
388 #include "mlir/IR/SymbolInterfaces.h.inc"
389 
390 #endif // MLIR_IR_SYMBOLTABLE_H
391