1 //===- AsmParserState.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Parser/AsmParserState.h"
10 #include "mlir/IR/Operation.h"
11 #include "mlir/IR/SymbolTable.h"
12
13 using namespace mlir;
14
15 //===----------------------------------------------------------------------===//
16 // AsmParserState::Impl
17 //===----------------------------------------------------------------------===//
18
19 struct AsmParserState::Impl {
20 /// A map from a SymbolRefAttr to a range of uses.
21 using SymbolUseMap =
22 DenseMap<Attribute, SmallVector<SmallVector<llvm::SMRange>, 0>>;
23
24 struct PartialOpDef {
PartialOpDefAsmParserState::Impl::PartialOpDef25 explicit PartialOpDef(const OperationName &opName) {
26 const auto *abstractOp = opName.getAbstractOperation();
27 if (abstractOp && abstractOp->hasTrait<OpTrait::SymbolTable>())
28 symbolTable = std::make_unique<SymbolUseMap>();
29 }
30
31 /// Return if this operation is a symbol table.
isSymbolTableAsmParserState::Impl::PartialOpDef32 bool isSymbolTable() const { return symbolTable.get(); }
33
34 /// If this operation is a symbol table, the following contains symbol uses
35 /// within this operation.
36 std::unique_ptr<SymbolUseMap> symbolTable;
37 };
38
39 /// Resolve any symbol table uses in the IR.
40 void resolveSymbolUses();
41
42 /// A mapping from operations in the input source file to their parser state.
43 SmallVector<std::unique_ptr<OperationDefinition>> operations;
44 DenseMap<Operation *, unsigned> operationToIdx;
45
46 /// A mapping from blocks in the input source file to their parser state.
47 SmallVector<std::unique_ptr<BlockDefinition>> blocks;
48 DenseMap<Block *, unsigned> blocksToIdx;
49
50 /// A set of value definitions that are placeholders for forward references.
51 /// This map should be empty if the parser finishes successfully.
52 DenseMap<Value, SmallVector<llvm::SMLoc>> placeholderValueUses;
53
54 /// The symbol table operations within the IR.
55 SmallVector<std::pair<Operation *, std::unique_ptr<SymbolUseMap>>>
56 symbolTableOperations;
57
58 /// A stack of partial operation definitions that have been started but not
59 /// yet finalized.
60 SmallVector<PartialOpDef> partialOperations;
61
62 /// A stack of symbol use scopes. This is used when collecting symbol table
63 /// uses during parsing.
64 SmallVector<SymbolUseMap *> symbolUseScopes;
65
66 /// A symbol table containing all of the symbol table operations in the IR.
67 SymbolTableCollection symbolTable;
68 };
69
resolveSymbolUses()70 void AsmParserState::Impl::resolveSymbolUses() {
71 SmallVector<Operation *> symbolOps;
72 for (auto &opAndUseMapIt : symbolTableOperations) {
73 for (auto &it : *opAndUseMapIt.second) {
74 symbolOps.clear();
75 if (failed(symbolTable.lookupSymbolIn(
76 opAndUseMapIt.first, it.first.cast<SymbolRefAttr>(), symbolOps)))
77 continue;
78
79 for (ArrayRef<llvm::SMRange> useRange : it.second) {
80 for (const auto &symIt : llvm::zip(symbolOps, useRange)) {
81 auto opIt = operationToIdx.find(std::get<0>(symIt));
82 if (opIt != operationToIdx.end())
83 operations[opIt->second]->symbolUses.push_back(std::get<1>(symIt));
84 }
85 }
86 }
87 }
88 }
89
90 //===----------------------------------------------------------------------===//
91 // AsmParserState
92 //===----------------------------------------------------------------------===//
93
AsmParserState()94 AsmParserState::AsmParserState() : impl(std::make_unique<Impl>()) {}
~AsmParserState()95 AsmParserState::~AsmParserState() {}
operator =(AsmParserState && other)96 AsmParserState &AsmParserState::operator=(AsmParserState &&other) {
97 impl = std::move(other.impl);
98 return *this;
99 }
100
101 //===----------------------------------------------------------------------===//
102 // Access State
103
getBlockDefs() const104 auto AsmParserState::getBlockDefs() const -> iterator_range<BlockDefIterator> {
105 return llvm::make_pointee_range(llvm::makeArrayRef(impl->blocks));
106 }
107
getBlockDef(Block * block) const108 auto AsmParserState::getBlockDef(Block *block) const
109 -> const BlockDefinition * {
110 auto it = impl->blocksToIdx.find(block);
111 return it == impl->blocksToIdx.end() ? nullptr : &*impl->blocks[it->second];
112 }
113
getOpDefs() const114 auto AsmParserState::getOpDefs() const -> iterator_range<OperationDefIterator> {
115 return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
116 }
117
getOpDef(Operation * op) const118 auto AsmParserState::getOpDef(Operation *op) const
119 -> const OperationDefinition * {
120 auto it = impl->operationToIdx.find(op);
121 return it == impl->operationToIdx.end() ? nullptr
122 : &*impl->operations[it->second];
123 }
124
convertIdLocToRange(llvm::SMLoc loc)125 llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
126 if (!loc.isValid())
127 return llvm::SMRange();
128
129 // Return if the given character is a valid identifier character.
130 auto isIdentifierChar = [](char c) {
131 return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
132 };
133
134 const char *curPtr = loc.getPointer();
135 while (*curPtr && isIdentifierChar(*(++curPtr)))
136 continue;
137 return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
138 }
139
140 //===----------------------------------------------------------------------===//
141 // Populate State
142
initialize(Operation * topLevelOp)143 void AsmParserState::initialize(Operation *topLevelOp) {
144 startOperationDefinition(topLevelOp->getName());
145
146 // If the top-level operation is a symbol table, push a new symbol scope.
147 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
148 if (partialOpDef.isSymbolTable())
149 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
150 }
151
finalize(Operation * topLevelOp)152 void AsmParserState::finalize(Operation *topLevelOp) {
153 assert(!impl->partialOperations.empty() &&
154 "expected valid partial operation definition");
155 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
156
157 // If this operation is a symbol table, resolve any symbol uses.
158 if (partialOpDef.isSymbolTable()) {
159 impl->symbolTableOperations.emplace_back(
160 topLevelOp, std::move(partialOpDef.symbolTable));
161 }
162 impl->resolveSymbolUses();
163 }
164
startOperationDefinition(const OperationName & opName)165 void AsmParserState::startOperationDefinition(const OperationName &opName) {
166 impl->partialOperations.emplace_back(opName);
167 }
168
finalizeOperationDefinition(Operation * op,llvm::SMRange nameLoc,llvm::SMLoc endLoc,ArrayRef<std::pair<unsigned,llvm::SMLoc>> resultGroups)169 void AsmParserState::finalizeOperationDefinition(
170 Operation *op, llvm::SMRange nameLoc, llvm::SMLoc endLoc,
171 ArrayRef<std::pair<unsigned, llvm::SMLoc>> resultGroups) {
172 assert(!impl->partialOperations.empty() &&
173 "expected valid partial operation definition");
174 Impl::PartialOpDef partialOpDef = impl->partialOperations.pop_back_val();
175
176 // Build the full operation definition.
177 std::unique_ptr<OperationDefinition> def =
178 std::make_unique<OperationDefinition>(op, nameLoc, endLoc);
179 for (auto &resultGroup : resultGroups)
180 def->resultGroups.emplace_back(resultGroup.first,
181 convertIdLocToRange(resultGroup.second));
182 impl->operationToIdx.try_emplace(op, impl->operations.size());
183 impl->operations.emplace_back(std::move(def));
184
185 // If this operation is a symbol table, resolve any symbol uses.
186 if (partialOpDef.isSymbolTable()) {
187 impl->symbolTableOperations.emplace_back(
188 op, std::move(partialOpDef.symbolTable));
189 }
190 }
191
startRegionDefinition()192 void AsmParserState::startRegionDefinition() {
193 assert(!impl->partialOperations.empty() &&
194 "expected valid partial operation definition");
195
196 // If the parent operation of this region is a symbol table, we also push a
197 // new symbol scope.
198 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
199 if (partialOpDef.isSymbolTable())
200 impl->symbolUseScopes.push_back(partialOpDef.symbolTable.get());
201 }
202
finalizeRegionDefinition()203 void AsmParserState::finalizeRegionDefinition() {
204 assert(!impl->partialOperations.empty() &&
205 "expected valid partial operation definition");
206
207 // If the parent operation of this region is a symbol table, pop the symbol
208 // scope for this region.
209 Impl::PartialOpDef &partialOpDef = impl->partialOperations.back();
210 if (partialOpDef.isSymbolTable())
211 impl->symbolUseScopes.pop_back();
212 }
213
addDefinition(Block * block,llvm::SMLoc location)214 void AsmParserState::addDefinition(Block *block, llvm::SMLoc location) {
215 auto it = impl->blocksToIdx.find(block);
216 if (it == impl->blocksToIdx.end()) {
217 impl->blocksToIdx.try_emplace(block, impl->blocks.size());
218 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(
219 block, convertIdLocToRange(location)));
220 return;
221 }
222
223 // If an entry already exists, this was a forward declaration that now has a
224 // proper definition.
225 impl->blocks[it->second]->definition.loc = convertIdLocToRange(location);
226 }
227
addDefinition(BlockArgument blockArg,llvm::SMLoc location)228 void AsmParserState::addDefinition(BlockArgument blockArg,
229 llvm::SMLoc location) {
230 auto it = impl->blocksToIdx.find(blockArg.getOwner());
231 assert(it != impl->blocksToIdx.end() &&
232 "expected owner block to have an entry");
233 BlockDefinition &def = *impl->blocks[it->second];
234 unsigned argIdx = blockArg.getArgNumber();
235
236 if (def.arguments.size() <= argIdx)
237 def.arguments.resize(argIdx + 1);
238 def.arguments[argIdx] = SMDefinition(convertIdLocToRange(location));
239 }
240
addUses(Value value,ArrayRef<llvm::SMLoc> locations)241 void AsmParserState::addUses(Value value, ArrayRef<llvm::SMLoc> locations) {
242 // Handle the case where the value is an operation result.
243 if (OpResult result = value.dyn_cast<OpResult>()) {
244 // Check to see if a definition for the parent operation has been recorded.
245 // If one hasn't, we treat the provided value as a placeholder value that
246 // will be refined further later.
247 Operation *parentOp = result.getOwner();
248 auto existingIt = impl->operationToIdx.find(parentOp);
249 if (existingIt == impl->operationToIdx.end()) {
250 impl->placeholderValueUses[value].append(locations.begin(),
251 locations.end());
252 return;
253 }
254
255 // If a definition does exist, locate the value's result group and add the
256 // use. The result groups are ordered by increasing start index, so we just
257 // need to find the last group that has a smaller/equal start index.
258 unsigned resultNo = result.getResultNumber();
259 OperationDefinition &def = *impl->operations[existingIt->second];
260 for (auto &resultGroup : llvm::reverse(def.resultGroups)) {
261 if (resultNo >= resultGroup.first) {
262 for (llvm::SMLoc loc : locations)
263 resultGroup.second.uses.push_back(convertIdLocToRange(loc));
264 return;
265 }
266 }
267 llvm_unreachable("expected valid result group for value use");
268 }
269
270 // Otherwise, this is a block argument.
271 BlockArgument arg = value.cast<BlockArgument>();
272 auto existingIt = impl->blocksToIdx.find(arg.getOwner());
273 assert(existingIt != impl->blocksToIdx.end() &&
274 "expected valid block definition for block argument");
275 BlockDefinition &blockDef = *impl->blocks[existingIt->second];
276 SMDefinition &argDef = blockDef.arguments[arg.getArgNumber()];
277 for (llvm::SMLoc loc : locations)
278 argDef.uses.emplace_back(convertIdLocToRange(loc));
279 }
280
addUses(Block * block,ArrayRef<llvm::SMLoc> locations)281 void AsmParserState::addUses(Block *block, ArrayRef<llvm::SMLoc> locations) {
282 auto it = impl->blocksToIdx.find(block);
283 if (it == impl->blocksToIdx.end()) {
284 it = impl->blocksToIdx.try_emplace(block, impl->blocks.size()).first;
285 impl->blocks.emplace_back(std::make_unique<BlockDefinition>(block));
286 }
287
288 BlockDefinition &def = *impl->blocks[it->second];
289 for (llvm::SMLoc loc : locations)
290 def.definition.uses.push_back(convertIdLocToRange(loc));
291 }
292
addUses(SymbolRefAttr refAttr,ArrayRef<llvm::SMRange> locations)293 void AsmParserState::addUses(SymbolRefAttr refAttr,
294 ArrayRef<llvm::SMRange> locations) {
295 // Ignore this symbol if no scopes are active.
296 if (impl->symbolUseScopes.empty())
297 return;
298
299 assert((refAttr.getNestedReferences().size() + 1) == locations.size() &&
300 "expected the same number of references as provided locations");
301 (*impl->symbolUseScopes.back())[refAttr].emplace_back(locations.begin(),
302 locations.end());
303 }
304
refineDefinition(Value oldValue,Value newValue)305 void AsmParserState::refineDefinition(Value oldValue, Value newValue) {
306 auto it = impl->placeholderValueUses.find(oldValue);
307 assert(it != impl->placeholderValueUses.end() &&
308 "expected `oldValue` to be a placeholder");
309 addUses(newValue, it->second);
310 impl->placeholderValueUses.erase(oldValue);
311 }
312