1 //===- AsmParserState.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Parser/AsmParserState.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/IR/SymbolTable.h"
12 
13 using namespace mlir;
14 
15 //===----------------------------------------------------------------------===//
16 // AsmParserState::Impl
17 //===----------------------------------------------------------------------===//
18 
19 struct AsmParserState::Impl {
20   /// A map from a SymbolRefAttr to a range of uses.
21   using SymbolUseMap =
22       DenseMap<Attribute, SmallVector<SmallVector<llvm::SMRange>, 0>>;
23 
24   struct PartialOpDef {
PartialOpDefAsmParserState::Impl::PartialOpDef25     explicit PartialOpDef(const OperationName &opName) {
26       const auto *abstractOp = opName.getAbstractOperation();
27       if (abstractOp && abstractOp->hasTrait<OpTrait::SymbolTable>())
28         symbolTable = std::make_unique<SymbolUseMap>();
29     }
30 
31     /// Return if this operation is a symbol table.
isSymbolTableAsmParserState::Impl::PartialOpDef32     bool isSymbolTable() const { return symbolTable.get(); }
33 
34     /// If this operation is a symbol table, the following contains symbol uses
35     /// within this operation.
36     std::unique_ptr<SymbolUseMap> symbolTable;
37   };
38 
39   /// Resolve any symbol table uses in the IR.
40   void resolveSymbolUses();
41 
42   /// A mapping from operations in the input source file to their parser state.
43   SmallVector<std::unique_ptr<OperationDefinition>> operations;
44   DenseMap<Operation *, unsigned> operationToIdx;
45 
46   /// A mapping from blocks in the input source file to their parser state.
47   SmallVector<std::unique_ptr<BlockDefinition>> blocks;
48   DenseMap<Block *, unsigned> blocksToIdx;
49 
50   /// A set of value definitions that are placeholders for forward references.
51   /// This map should be empty if the parser finishes successfully.
52   DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
53 
54   /// The symbol table operations within the IR.
55   SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
56       symbolTableOperations;
57 
58   /// A stack of partial operation definitions that have been started but not
59   /// yet finalized.
60   SmallVector<PartialOpDef> partialOperations;
61 
62   /// A stack of symbol use scopes. This is used when collecting symbol table
63   /// uses during parsing.
64   SmallVector<SymbolUseMap *> symbolUseScopes;
65 
66   /// A symbol table containing all of the symbol table operations in the IR.
67   SymbolTableCollection symbolTable;
68 };
69 
resolveSymbolUses()70 void AsmParserState::Impl::resolveSymbolUses() {
71   SmallVector<Operation *> symbolOps;
72   for (auto &opAndUseMapIt : symbolTableOperations) {
73     for (auto &it : *opAndUseMapIt.second) {
74       symbolOps.clear();
75       if (failed(symbolTable.lookupSymbolIn(
76               opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
77         continue;
78 
79       for (ArrayRef<llvm::SMRange> useRange : it.second) {
80         for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
81           auto opIt = operationToIdx.find(std::get<0>(symIt));
82           if (opIt != operationToIdx.end())
83             operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
84         }
85       }
86     }
87   }
88 }
89 
90 //===----------------------------------------------------------------------===//
91 // AsmParserState
92 //===----------------------------------------------------------------------===//
93 
AsmParserState()94 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
~AsmParserState()95 AsmParserState::~AsmParserState() {}
operator =(AsmParserState && other)96 AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
97   impl = std::move(other.impl);
98   return *this;
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Access State
103 
getBlockDefs() const104 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
105   return llvm::make_pointee_range(llvm::makeArrayRef(impl->blocks));
106 }
107 
getBlockDef(Block * block) const108 auto AsmParserState::getBlockDef(Block *block) const
109     -> const BlockDefinition * {
110   auto it = impl->blocksToIdx.find(block);
111   return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
112 }
113 
getOpDefs() const114 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
115   return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
116 }
117 
getOpDef(Operation * op) const118 auto AsmParserState::getOpDef(Operation *op) const
119     -> const OperationDefinition * {
120   auto it = impl->operationToIdx.find(op);
121   return it == impl->operationToIdx.end() ? nullptr
122                                           : &*impl->operations[it->second];
123 }
124 
convertIdLocToRange(llvm::SMLoc loc)125 llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
126   if (!loc.isValid())
127     return llvm::SMRange();
128 
129   // Return if the given character is a valid identifier character.
130   auto isIdentifierChar = [](char c) {
131     return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
132   };
133 
134   const char *curPtr = loc.getPointer();
135   while (*curPtr && isIdentifierChar(*(++curPtr)))
136     continue;
137   return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // Populate State
142 
initialize(Operation * topLevelOp)143 void AsmParserState::initialize(Operation *topLevelOp) {
144   startOperationDefinition(topLevelOp->getName());
145 
146   // If the top-level operation is a symbol table, push a new symbol scope.
147   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
148   if (partialOpDef.isSymbolTable())
149     impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
150 }
151 
finalize(Operation * topLevelOp)152 void AsmParserState::finalize(Operation *topLevelOp) {
153   assert(!impl->partialOperations.empty() &&
154          "expected valid partial operation definition");
155   Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
156 
157   // If this operation is a symbol table, resolve any symbol uses.
158   if (partialOpDef.isSymbolTable()) {
159     impl->symbolTableOperations.emplace_back(
160         topLevelOp, std::move(partialOpDef.symbolTable));
161   }
162   impl->resolveSymbolUses();
163 }
164 
startOperationDefinition(const OperationName & opName)165 void AsmParserState::startOperationDefinition(const OperationName &opName) {
166   impl->partialOperations.emplace_back(opName);
167 }
168 
finalizeOperationDefinition(Operation * op,llvm::SMRange nameLoc,llvm::SMLoc endLoc,ArrayRef<std::pair<unsigned,llvm::SMLoc>> resultGroups)169 void AsmParserState::finalizeOperationDefinition(
170     Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
171     ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
172   assert(!impl->partialOperations.empty() &&
173          "expected valid partial operation definition");
174   Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
175 
176   // Build the full operation definition.
177   std::unique_ptr<OperationDefinition> def =
178       std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
179   for (auto &resultGroup : resultGroups)
180     def->resultGroups.emplace_back(resultGroup.first,
181                                    convertIdLocToRange(resultGroup.second));
182   impl->operationToIdx.try_emplace(op, impl->operations.size());
183   impl->operations.emplace_back(std::move(def));
184 
185   // If this operation is a symbol table, resolve any symbol uses.
186   if (partialOpDef.isSymbolTable()) {
187     impl->symbolTableOperations.emplace_back(
188         op, std::move(partialOpDef.symbolTable));
189   }
190 }
191 
startRegionDefinition()192 void AsmParserState::startRegionDefinition() {
193   assert(!impl->partialOperations.empty() &&
194          "expected valid partial operation definition");
195 
196   // If the parent operation of this region is a symbol table, we also push a
197   // new symbol scope.
198   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
199   if (partialOpDef.isSymbolTable())
200     impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
201 }
202 
finalizeRegionDefinition()203 void AsmParserState::finalizeRegionDefinition() {
204   assert(!impl->partialOperations.empty() &&
205          "expected valid partial operation definition");
206 
207   // If the parent operation of this region is a symbol table, pop the symbol
208   // scope for this region.
209   Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
210   if (partialOpDef.isSymbolTable())
211     impl->symbolUseScopes.pop_back();
212 }
213 
addDefinition(Block * block,llvm::SMLoc location)214 void AsmParserState::addDefinition(Block *block, llvm::SMLoc location) {
215   auto it = impl->blocksToIdx.find(block);
216   if (it == impl->blocksToIdx.end()) {
217     impl->blocksToIdx.try_emplace(block, impl->blocks.size());
218     impl->blocks.emplace_back(std::make_unique<BlockDefinition>(
219         block, convertIdLocToRange(location)));
220     return;
221   }
222 
223   // If an entry already exists, this was a forward declaration that now has a
224   // proper definition.
225   impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
226 }
227 
addDefinition(BlockArgument blockArg,llvm::SMLoc location)228 void AsmParserState::addDefinition(BlockArgument blockArg,
229                                    llvm::SMLoc location) {
230   auto it = impl->blocksToIdx.find(blockArg.getOwner());
231   assert(it != impl->blocksToIdx.end() &&
232          "expected owner block to have an entry");
233   BlockDefinition &def = *impl->blocks[it->second];
234   unsigned argIdx = blockArg.getArgNumber();
235 
236   if (def.arguments.size() <= argIdx)
237     def.arguments.resize(argIdx + 1);
238   def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location));
239 }
240 
addUses(Value value,ArrayRef<llvm::SMLoc> locations)241 void AsmParserState::addUses(Value value, ArrayRef<llvm::SMLoc> locations) {
242   // Handle the case where the value is an operation result.
243   if (OpResult result = value.dyn_cast<OpResult>()) {
244     // Check to see if a definition for the parent operation has been recorded.
245     // If one hasn't, we treat the provided value as a placeholder value that
246     // will be refined further later.
247     Operation *parentOp = result.getOwner();
248     auto existingIt = impl->operationToIdx.find(parentOp);
249     if (existingIt == impl->operationToIdx.end()) {
250       impl->placeholderValueUses[value].append(locations.begin(),
251                                                locations.end());
252       return;
253     }
254 
255     // If a definition does exist, locate the value's result group and add the
256     // use. The result groups are ordered by increasing start index, so we just
257     // need to find the last group that has a smaller/equal start index.
258     unsigned resultNo = result.getResultNumber();
259     OperationDefinition &def = *impl->operations[existingIt->second];
260     for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
261       if (resultNo >= resultGroup.first) {
262         for (llvm::SMLoc loc : locations)
263           resultGroup.second.uses.push_back(convertIdLocToRange(loc));
264         return;
265       }
266     }
267     llvm_unreachable("expected valid result group for value use");
268   }
269 
270   // Otherwise, this is a block argument.
271   BlockArgument arg = value.cast<BlockArgument>();
272   auto existingIt = impl->blocksToIdx.find(arg.getOwner());
273   assert(existingIt != impl->blocksToIdx.end() &&
274          "expected valid block definition for block argument");
275   BlockDefinition &blockDef = *impl->blocks[existingIt->second];
276   SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
277   for (llvm::SMLoc loc : locations)
278     argDef.uses.emplace_back(convertIdLocToRange(loc));
279 }
280 
addUses(Block * block,ArrayRef<llvm::SMLoc> locations)281 void AsmParserState::addUses(Block *block, ArrayRef<llvm::SMLoc> locations) {
282   auto it = impl->blocksToIdx.find(block);
283   if (it == impl->blocksToIdx.end()) {
284     it = impl->blocksToIdx.try_emplace(block, impl->blocks.size()).first;
285     impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
286   }
287 
288   BlockDefinition &def = *impl->blocks[it->second];
289   for (llvm::SMLoc loc : locations)
290     def.definition.uses.push_back(convertIdLocToRange(loc));
291 }
292 
addUses(SymbolRefAttr refAttr,ArrayRef<llvm::SMRange> locations)293 void AsmParserState::addUses(SymbolRefAttr refAttr,
294                              ArrayRef<llvm::SMRange> locations) {
295   // Ignore this symbol if no scopes are active.
296   if (impl->symbolUseScopes.empty())
297     return;
298 
299   assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
300          "expected the same number of references as provided locations");
301   (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
302                                                         locations.end());
303 }
304 
refineDefinition(Value oldValue,Value newValue)305 void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
306   auto it = impl->placeholderValueUses.find(oldValue);
307   assert(it != impl->placeholderValueUses.end() &&
308          "expected `oldValue` to be a placeholder");
309   addUses(newValue, it->second);
310   impl->placeholderValueUses.erase(oldValue);
311 }
312