1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23
24 #define DEBUG_TYPE "pdl-bytecode"
25
26 using namespace mlir;
27 using namespace mlir::detail;
28
29 //===----------------------------------------------------------------------===//
30 // PDLByteCodePattern
31 //===----------------------------------------------------------------------===//
32
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34 ByteCodeAddr rewriterAddr) {
35 SmallVector<StringRef, 8> generatedOps;
36 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37 generatedOps =
38 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39
40 PatternBenefit benefit = matchOp.benefit();
41 MLIRContext *ctx = matchOp.getContext();
42
43 // Check to see if this is pattern matches a specific operation type.
44 if (Optional<StringRef> rootKind = matchOp.rootKind())
45 return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46 ctx);
47 return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48 MatchAnyOpTypeTag());
49 }
50
51 //===----------------------------------------------------------------------===//
52 // PDLByteCodeMutableState
53 //===----------------------------------------------------------------------===//
54
55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56 /// to the position of the pattern within the range returned by
57 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59 PatternBenefit benefit) {
60 currentPatternBenefits[patternIndex] = benefit;
61 }
62
63 //===----------------------------------------------------------------------===//
64 // Bytecode OpCodes
65 //===----------------------------------------------------------------------===//
66
67 namespace {
68 enum OpCode : ByteCodeField {
69 /// Apply an externally registered constraint.
70 ApplyConstraint,
71 /// Apply an externally registered rewrite.
72 ApplyRewrite,
73 /// Check if two generic values are equal.
74 AreEqual,
75 /// Unconditional branch.
76 Branch,
77 /// Compare the operand count of an operation with a constant.
78 CheckOperandCount,
79 /// Compare the name of an operation with a constant.
80 CheckOperationName,
81 /// Compare the result count of an operation with a constant.
82 CheckResultCount,
83 /// Invoke a native creation method.
84 CreateNative,
85 /// Create an operation.
86 CreateOperation,
87 /// Erase an operation.
88 EraseOp,
89 /// Terminate a matcher or rewrite sequence.
90 Finalize,
91 /// Get a specific attribute of an operation.
92 GetAttribute,
93 /// Get the type of an attribute.
94 GetAttributeType,
95 /// Get the defining operation of a value.
96 GetDefiningOp,
97 /// Get a specific operand of an operation.
98 GetOperand0,
99 GetOperand1,
100 GetOperand2,
101 GetOperand3,
102 GetOperandN,
103 /// Get a specific result of an operation.
104 GetResult0,
105 GetResult1,
106 GetResult2,
107 GetResult3,
108 GetResultN,
109 /// Get the type of a value.
110 GetValueType,
111 /// Check if a generic value is not null.
112 IsNotNull,
113 /// Record a successful pattern match.
114 RecordMatch,
115 /// Replace an operation.
116 ReplaceOp,
117 /// Compare an attribute with a set of constants.
118 SwitchAttribute,
119 /// Compare the operand count of an operation with a set of constants.
120 SwitchOperandCount,
121 /// Compare the name of an operation with a set of constants.
122 SwitchOperationName,
123 /// Compare the result count of an operation with a set of constants.
124 SwitchResultCount,
125 /// Compare a type with a set of constants.
126 SwitchType,
127 };
128
129 enum class PDLValueKind { Attribute, Operation, Type, Value };
130 } // end anonymous namespace
131
132 //===----------------------------------------------------------------------===//
133 // ByteCode Generation
134 //===----------------------------------------------------------------------===//
135
136 //===----------------------------------------------------------------------===//
137 // Generator
138
139 namespace {
140 struct ByteCodeWriter;
141
142 /// This class represents the main generator for the pattern bytecode.
143 class Generator {
144 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLCreateFunction> & createFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)145 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
146 SmallVectorImpl<ByteCodeField> &matcherByteCode,
147 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
148 SmallVectorImpl<PDLByteCodePattern> &patterns,
149 ByteCodeField &maxValueMemoryIndex,
150 llvm::StringMap<PDLConstraintFunction> &constraintFns,
151 llvm::StringMap<PDLCreateFunction> &createFns,
152 llvm::StringMap<PDLRewriteFunction> &rewriteFns)
153 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
154 rewriterByteCode(rewriterByteCode), patterns(patterns),
155 maxValueMemoryIndex(maxValueMemoryIndex) {
156 for (auto it : llvm::enumerate(constraintFns))
157 constraintToMemIndex.try_emplace(it.value().first(), it.index());
158 for (auto it : llvm::enumerate(createFns))
159 nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
160 for (auto it : llvm::enumerate(rewriteFns))
161 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
162 }
163
164 /// Generate the bytecode for the given PDL interpreter module.
165 void generate(ModuleOp module);
166
167 /// Return the memory index to use for the given value.
getMemIndex(Value value)168 ByteCodeField &getMemIndex(Value value) {
169 assert(valueToMemIndex.count(value) &&
170 "expected memory index to be assigned");
171 return valueToMemIndex[value];
172 }
173
174 /// Return an index to use when referring to the given data that is uniqued in
175 /// the MLIR context.
176 template <typename T>
177 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)178 getMemIndex(T val) {
179 const void *opaqueVal = val.getAsOpaquePointer();
180
181 // Get or insert a reference to this value.
182 auto it = uniquedDataToMemIndex.try_emplace(
183 opaqueVal, maxValueMemoryIndex + uniquedData.size());
184 if (it.second)
185 uniquedData.push_back(opaqueVal);
186 return it.first->second;
187 }
188
189 private:
190 /// Allocate memory indices for the results of operations within the matcher
191 /// and rewriters.
192 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
193
194 /// Generate the bytecode for the given operation.
195 void generate(Operation *op, ByteCodeWriter &writer);
196 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
197 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
198 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
199 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
200 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
201 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
202 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
203 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
204 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
205 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
206 void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
207 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
208 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
209 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
210 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
211 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
212 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
213 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
214 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
215 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
216 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
217 void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
218 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
219 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
220 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
221 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
222 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
223 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
224 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
225 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
226
227 /// Mapping from value to its corresponding memory index.
228 DenseMap<Value, ByteCodeField> valueToMemIndex;
229
230 /// Mapping from the name of an externally registered rewrite to its index in
231 /// the bytecode registry.
232 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
233
234 /// Mapping from the name of an externally registered constraint to its index
235 /// in the bytecode registry.
236 llvm::StringMap<ByteCodeField> constraintToMemIndex;
237
238 /// Mapping from the name of an externally registered creation method to its
239 /// index in the bytecode registry.
240 llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
241
242 /// Mapping from rewriter function name to the bytecode address of the
243 /// rewriter function in byte.
244 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
245
246 /// Mapping from a uniqued storage object to its memory index within
247 /// `uniquedData`.
248 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
249
250 /// The current MLIR context.
251 MLIRContext *ctx;
252
253 /// Data of the ByteCode class to be populated.
254 std::vector<const void *> &uniquedData;
255 SmallVectorImpl<ByteCodeField> &matcherByteCode;
256 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
257 SmallVectorImpl<PDLByteCodePattern> &patterns;
258 ByteCodeField &maxValueMemoryIndex;
259 };
260
261 /// This class provides utilities for writing a bytecode stream.
262 struct ByteCodeWriter {
ByteCodeWriter__anon699baea10211::ByteCodeWriter263 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
264 : bytecode(bytecode), generator(generator) {}
265
266 /// Append a field to the bytecode.
append__anon699baea10211::ByteCodeWriter267 void append(ByteCodeField field) { bytecode.push_back(field); }
append__anon699baea10211::ByteCodeWriter268 void append(OpCode opCode) { bytecode.push_back(opCode); }
269
270 /// Append an address to the bytecode.
append__anon699baea10211::ByteCodeWriter271 void append(ByteCodeAddr field) {
272 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
273 "unexpected ByteCode address size");
274
275 ByteCodeField fieldParts[2];
276 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
277 bytecode.append({fieldParts[0], fieldParts[1]});
278 }
279
280 /// Append a successor range to the bytecode, the exact address will need to
281 /// be resolved later.
append__anon699baea10211::ByteCodeWriter282 void append(SuccessorRange successors) {
283 // Add back references to the any successors so that the address can be
284 // resolved later.
285 for (Block *successor : successors) {
286 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
287 append(ByteCodeAddr(0));
288 }
289 }
290
291 /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anon699baea10211::ByteCodeWriter292 void appendPDLValueList(OperandRange values) {
293 bytecode.push_back(values.size());
294 for (Value value : values) {
295 // Append the type of the value in addition to the value itself.
296 PDLValueKind kind =
297 TypeSwitch<Type, PDLValueKind>(value.getType())
298 .Case<pdl::AttributeType>(
299 [](Type) { return PDLValueKind::Attribute; })
300 .Case<pdl::OperationType>(
301 [](Type) { return PDLValueKind::Operation; })
302 .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
303 .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
304 bytecode.push_back(static_cast<ByteCodeField>(kind));
305 append(value);
306 }
307 }
308
309 /// Check if the given class `T` has an iterator type.
310 template <typename T, typename... Args>
311 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
312
313 /// Append a value that will be stored in a memory slot and not inline within
314 /// the bytecode.
315 template <typename T>
316 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
317 std::is_pointer<T>::value>
append__anon699baea10211::ByteCodeWriter318 append(T value) {
319 bytecode.push_back(generator.getMemIndex(value));
320 }
321
322 /// Append a range of values.
323 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
324 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anon699baea10211::ByteCodeWriter325 append(T range) {
326 bytecode.push_back(llvm::size(range));
327 for (auto it : range)
328 append(it);
329 }
330
331 /// Append a variadic number of fields to the bytecode.
332 template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anon699baea10211::ByteCodeWriter333 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
334 append(field);
335 append(field2, fields...);
336 }
337
338 /// Successor references in the bytecode that have yet to be resolved.
339 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
340
341 /// The underlying bytecode buffer.
342 SmallVectorImpl<ByteCodeField> &bytecode;
343
344 /// The main generator producing PDL.
345 Generator &generator;
346 };
347 } // end anonymous namespace
348
generate(ModuleOp module)349 void Generator::generate(ModuleOp module) {
350 FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
351 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
352 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
353 pdl_interp::PDLInterpDialect::getRewriterModuleName());
354 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
355
356 // Allocate memory indices for the results of operations within the matcher
357 // and rewriters.
358 allocateMemoryIndices(matcherFunc, rewriterModule);
359
360 // Generate code for the rewriter functions.
361 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
362 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
363 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
364 for (Operation &op : rewriterFunc.getOps())
365 generate(&op, rewriterByteCodeWriter);
366 }
367 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
368 "unexpected branches in rewriter function");
369
370 // Generate code for the matcher function.
371 DenseMap<Block *, ByteCodeAddr> blockToAddr;
372 llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
373 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
374 for (Block *block : rpot) {
375 // Keep track of where this block begins within the matcher function.
376 blockToAddr.try_emplace(block, matcherByteCode.size());
377 for (Operation &op : *block)
378 generate(&op, matcherByteCodeWriter);
379 }
380
381 // Resolve successor references in the matcher.
382 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
383 ByteCodeAddr addr = blockToAddr[it.first];
384 for (unsigned offsetToFix : it.second)
385 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
386 }
387 }
388
allocateMemoryIndices(FuncOp matcherFunc,ModuleOp rewriterModule)389 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
390 ModuleOp rewriterModule) {
391 // Rewriters use simplistic allocation scheme that simply assigns an index to
392 // each result.
393 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
394 ByteCodeField index = 0;
395 for (BlockArgument arg : rewriterFunc.getArguments())
396 valueToMemIndex.try_emplace(arg, index++);
397 rewriterFunc.getBody().walk([&](Operation *op) {
398 for (Value result : op->getResults())
399 valueToMemIndex.try_emplace(result, index++);
400 });
401 if (index > maxValueMemoryIndex)
402 maxValueMemoryIndex = index;
403 }
404
405 // The matcher function uses a more sophisticated numbering that tries to
406 // minimize the number of memory indices assigned. This is done by determining
407 // a live range of the values within the matcher, then the allocation is just
408 // finding the minimal number of overlapping live ranges. This is essentially
409 // a simplified form of register allocation where we don't necessarily have a
410 // limited number of registers, but we still want to minimize the number used.
411 DenseMap<Operation *, ByteCodeField> opToIndex;
412 matcherFunc.getBody().walk([&](Operation *op) {
413 opToIndex.insert(std::make_pair(op, opToIndex.size()));
414 });
415
416 // Liveness info for each of the defs within the matcher.
417 using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
418 LivenessSet::Allocator allocator;
419 DenseMap<Value, LivenessSet> valueDefRanges;
420
421 // Assign the root operation being matched to slot 0.
422 BlockArgument rootOpArg = matcherFunc.getArgument(0);
423 valueToMemIndex[rootOpArg] = 0;
424
425 // Walk each of the blocks, computing the def interval that the value is used.
426 Liveness matcherLiveness(matcherFunc);
427 for (Block &block : matcherFunc.getBody()) {
428 const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
429 assert(info && "expected liveness info for block");
430 auto processValue = [&](Value value, Operation *firstUseOrDef) {
431 // We don't need to process the root op argument, this value is always
432 // assigned to the first memory slot.
433 if (value == rootOpArg)
434 return;
435
436 // Set indices for the range of this block that the value is used.
437 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
438 defRangeIt->second.insert(
439 opToIndex[firstUseOrDef],
440 opToIndex[info->getEndOperation(value, firstUseOrDef)],
441 /*dummyValue*/ 0);
442 };
443
444 // Process the live-ins of this block.
445 for (Value liveIn : info->in())
446 processValue(liveIn, &block.front());
447
448 // Process any new defs within this block.
449 for (Operation &op : block)
450 for (Value result : op.getResults())
451 processValue(result, &op);
452 }
453
454 // Greedily allocate memory slots using the computed def live ranges.
455 std::vector<LivenessSet> allocatedIndices;
456 for (auto &defIt : valueDefRanges) {
457 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
458 LivenessSet &defSet = defIt.second;
459
460 // Try to allocate to an existing index.
461 for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
462 LivenessSet &existingIndex = existingIndexIt.value();
463 llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
464 defIt.second, existingIndex);
465 if (overlaps.valid())
466 continue;
467 // Union the range of the def within the existing index.
468 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
469 existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
470 memIndex = existingIndexIt.index() + 1;
471 }
472
473 // If no existing index could be used, add a new one.
474 if (memIndex == 0) {
475 allocatedIndices.emplace_back(allocator);
476 for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
477 allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
478 memIndex = allocatedIndices.size();
479 }
480 }
481
482 // Update the max number of indices.
483 ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
484 if (numMatcherIndices > maxValueMemoryIndex)
485 maxValueMemoryIndex = numMatcherIndices;
486 }
487
generate(Operation * op,ByteCodeWriter & writer)488 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
489 TypeSwitch<Operation *>(op)
490 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
491 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
492 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
493 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
494 pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
495 pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
496 pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
497 pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
498 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
499 pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
500 pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
501 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
502 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
503 pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
504 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
505 [&](auto interpOp) { this->generate(interpOp, writer); })
506 .Default([](Operation *) {
507 llvm_unreachable("unknown `pdl_interp` operation");
508 });
509 }
510
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)511 void Generator::generate(pdl_interp::ApplyConstraintOp op,
512 ByteCodeWriter &writer) {
513 assert(constraintToMemIndex.count(op.name()) &&
514 "expected index for constraint function");
515 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
516 op.constParamsAttr());
517 writer.appendPDLValueList(op.args());
518 writer.append(op.getSuccessors());
519 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)520 void Generator::generate(pdl_interp::ApplyRewriteOp op,
521 ByteCodeWriter &writer) {
522 assert(externalRewriterToMemIndex.count(op.name()) &&
523 "expected index for rewrite function");
524 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
525 op.constParamsAttr(), op.root());
526 writer.appendPDLValueList(op.args());
527 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)528 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
529 writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
530 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)531 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
532 writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
533 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)534 void Generator::generate(pdl_interp::CheckAttributeOp op,
535 ByteCodeWriter &writer) {
536 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
537 op.getSuccessors());
538 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)539 void Generator::generate(pdl_interp::CheckOperandCountOp op,
540 ByteCodeWriter &writer) {
541 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
542 op.getSuccessors());
543 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)544 void Generator::generate(pdl_interp::CheckOperationNameOp op,
545 ByteCodeWriter &writer) {
546 writer.append(OpCode::CheckOperationName, op.operation(),
547 OperationName(op.name(), ctx), op.getSuccessors());
548 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)549 void Generator::generate(pdl_interp::CheckResultCountOp op,
550 ByteCodeWriter &writer) {
551 writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
552 op.getSuccessors());
553 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)554 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
555 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
556 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)557 void Generator::generate(pdl_interp::CreateAttributeOp op,
558 ByteCodeWriter &writer) {
559 // Simply repoint the memory index of the result to the constant.
560 getMemIndex(op.attribute()) = getMemIndex(op.value());
561 }
generate(pdl_interp::CreateNativeOp op,ByteCodeWriter & writer)562 void Generator::generate(pdl_interp::CreateNativeOp op,
563 ByteCodeWriter &writer) {
564 assert(nativeCreateToMemIndex.count(op.name()) &&
565 "expected index for creation function");
566 writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
567 op.result(), op.constParamsAttr());
568 writer.appendPDLValueList(op.args());
569 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)570 void Generator::generate(pdl_interp::CreateOperationOp op,
571 ByteCodeWriter &writer) {
572 writer.append(OpCode::CreateOperation, op.operation(),
573 OperationName(op.name(), ctx), op.operands());
574
575 // Add the attributes.
576 OperandRange attributes = op.attributes();
577 writer.append(static_cast<ByteCodeField>(attributes.size()));
578 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
579 writer.append(
580 Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
581 std::get<1>(it));
582 }
583 writer.append(op.types());
584 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)585 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
586 // Simply repoint the memory index of the result to the constant.
587 getMemIndex(op.result()) = getMemIndex(op.value());
588 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)589 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
590 writer.append(OpCode::EraseOp, op.operation());
591 }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)592 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
593 writer.append(OpCode::Finalize);
594 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)595 void Generator::generate(pdl_interp::GetAttributeOp op,
596 ByteCodeWriter &writer) {
597 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
598 Identifier::get(op.name(), ctx));
599 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)600 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
601 ByteCodeWriter &writer) {
602 writer.append(OpCode::GetAttributeType, op.result(), op.value());
603 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)604 void Generator::generate(pdl_interp::GetDefiningOpOp op,
605 ByteCodeWriter &writer) {
606 writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
607 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)608 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
609 uint32_t index = op.index();
610 if (index < 4)
611 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
612 else
613 writer.append(OpCode::GetOperandN, index);
614 writer.append(op.operation(), op.value());
615 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)616 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
617 uint32_t index = op.index();
618 if (index < 4)
619 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
620 else
621 writer.append(OpCode::GetResultN, index);
622 writer.append(op.operation(), op.value());
623 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)624 void Generator::generate(pdl_interp::GetValueTypeOp op,
625 ByteCodeWriter &writer) {
626 writer.append(OpCode::GetValueType, op.result(), op.value());
627 }
generate(pdl_interp::InferredTypeOp op,ByteCodeWriter & writer)628 void Generator::generate(pdl_interp::InferredTypeOp op,
629 ByteCodeWriter &writer) {
630 // InferType maps to a null type as a marker for inferring a result type.
631 getMemIndex(op.type()) = getMemIndex(Type());
632 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)633 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
634 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
635 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)636 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
637 ByteCodeField patternIndex = patterns.size();
638 patterns.emplace_back(PDLByteCodePattern::create(
639 op, rewriterToAddr[op.rewriter().getLeafReference()]));
640 writer.append(OpCode::RecordMatch, patternIndex,
641 SuccessorRange(op.getOperation()), op.matchedOps(),
642 op.inputs());
643 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)644 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
645 writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
646 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)647 void Generator::generate(pdl_interp::SwitchAttributeOp op,
648 ByteCodeWriter &writer) {
649 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
650 op.getSuccessors());
651 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)652 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
653 ByteCodeWriter &writer) {
654 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
655 op.getSuccessors());
656 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)657 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
658 ByteCodeWriter &writer) {
659 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
660 return OperationName(attr.cast<StringAttr>().getValue(), ctx);
661 });
662 writer.append(OpCode::SwitchOperationName, op.operation(), cases,
663 op.getSuccessors());
664 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)665 void Generator::generate(pdl_interp::SwitchResultCountOp op,
666 ByteCodeWriter &writer) {
667 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
668 op.getSuccessors());
669 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)670 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
671 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
672 op.getSuccessors());
673 }
674
675 //===----------------------------------------------------------------------===//
676 // PDLByteCode
677 //===----------------------------------------------------------------------===//
678
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLCreateFunction> createFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)679 PDLByteCode::PDLByteCode(ModuleOp module,
680 llvm::StringMap<PDLConstraintFunction> constraintFns,
681 llvm::StringMap<PDLCreateFunction> createFns,
682 llvm::StringMap<PDLRewriteFunction> rewriteFns) {
683 Generator generator(module.getContext(), uniquedData, matcherByteCode,
684 rewriterByteCode, patterns, maxValueMemoryIndex,
685 constraintFns, createFns, rewriteFns);
686 generator.generate(module);
687
688 // Initialize the external functions.
689 for (auto &it : constraintFns)
690 constraintFunctions.push_back(std::move(it.second));
691 for (auto &it : createFns)
692 createFunctions.push_back(std::move(it.second));
693 for (auto &it : rewriteFns)
694 rewriteFunctions.push_back(std::move(it.second));
695 }
696
697 /// Initialize the given state such that it can be used to execute the current
698 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const699 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
700 state.memory.resize(maxValueMemoryIndex, nullptr);
701 state.currentPatternBenefits.reserve(patterns.size());
702 for (const PDLByteCodePattern &pattern : patterns)
703 state.currentPatternBenefits.push_back(pattern.getBenefit());
704 }
705
706 //===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
708
709 namespace {
710 /// This class provides support for executing a bytecode stream.
711 class ByteCodeExecutor {
712 public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLCreateFunction> createFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)713 ByteCodeExecutor(const ByteCodeField *curCodeIt,
714 MutableArrayRef<const void *> memory,
715 ArrayRef<const void *> uniquedMemory,
716 ArrayRef<ByteCodeField> code,
717 ArrayRef<PatternBenefit> currentPatternBenefits,
718 ArrayRef<PDLByteCodePattern> patterns,
719 ArrayRef<PDLConstraintFunction> constraintFunctions,
720 ArrayRef<PDLCreateFunction> createFunctions,
721 ArrayRef<PDLRewriteFunction> rewriteFunctions)
722 : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
723 code(code), currentPatternBenefits(currentPatternBenefits),
724 patterns(patterns), constraintFunctions(constraintFunctions),
725 createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
726
727 /// Start executing the code at the current bytecode index. `matches` is an
728 /// optional field provided when this function is executed in a matching
729 /// context.
730 void execute(PatternRewriter &rewriter,
731 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
732 Optional<Location> mainRewriteLoc = {});
733
734 private:
735 /// Read a value from the bytecode buffer, optionally skipping a certain
736 /// number of prefix values. These methods always update the buffer to point
737 /// to the next field after the read data.
738 template <typename T = ByteCodeField>
read(size_t skipN=0)739 T read(size_t skipN = 0) {
740 curCodeIt += skipN;
741 return readImpl<T>();
742 }
read(size_t skipN=0)743 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
744
745 /// Read a list of values from the bytecode buffer.
746 template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)747 void readList(SmallVectorImpl<T> &list) {
748 list.clear();
749 for (unsigned i = 0, e = read(); i != e; ++i)
750 list.push_back(read<ValueT>());
751 }
752
753 /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)754 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
755 /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)756 void selectJump(size_t destIndex) {
757 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
758 }
759
760 /// Handle a switch operation with the provided value and cases.
761 template <typename T, typename RangeT>
handleSwitch(const T & value,RangeT && cases)762 void handleSwitch(const T &value, RangeT &&cases) {
763 LLVM_DEBUG({
764 llvm::dbgs() << " * Value: " << value << "\n"
765 << " * Cases: ";
766 llvm::interleaveComma(cases, llvm::dbgs());
767 llvm::dbgs() << "\n\n";
768 });
769
770 // Check to see if the attribute value is within the case list. Jump to
771 // the correct successor index based on the result.
772 for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
773 if (*it == value)
774 return selectJump(size_t((it - cases.begin()) + 1));
775 selectJump(size_t(0));
776 }
777
778 /// Internal implementation of reading various data types from the bytecode
779 /// stream.
780 template <typename T>
readFromMemory()781 const void *readFromMemory() {
782 size_t index = *curCodeIt++;
783
784 // If this type is an SSA value, it can only be stored in non-const memory.
785 if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
786 return memory[index];
787
788 // Otherwise, if this index is not inbounds it is uniqued.
789 return uniquedMemory[index - memory.size()];
790 }
791 template <typename T>
readImpl()792 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
793 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
794 }
795 template <typename T>
796 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
797 T>
readImpl()798 readImpl() {
799 return T(T::getFromOpaquePointer(readFromMemory<T>()));
800 }
801 template <typename T>
readImpl()802 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
803 switch (static_cast<PDLValueKind>(read())) {
804 case PDLValueKind::Attribute:
805 return read<Attribute>();
806 case PDLValueKind::Operation:
807 return read<Operation *>();
808 case PDLValueKind::Type:
809 return read<Type>();
810 case PDLValueKind::Value:
811 return read<Value>();
812 }
813 llvm_unreachable("unhandled PDLValueKind");
814 }
815 template <typename T>
readImpl()816 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
817 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
818 "unexpected ByteCode address size");
819 ByteCodeAddr result;
820 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
821 curCodeIt += 2;
822 return result;
823 }
824 template <typename T>
readImpl()825 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
826 return *curCodeIt++;
827 }
828
829 /// The underlying bytecode buffer.
830 const ByteCodeField *curCodeIt;
831
832 /// The current execution memory.
833 MutableArrayRef<const void *> memory;
834
835 /// References to ByteCode data necessary for execution.
836 ArrayRef<const void *> uniquedMemory;
837 ArrayRef<ByteCodeField> code;
838 ArrayRef<PatternBenefit> currentPatternBenefits;
839 ArrayRef<PDLByteCodePattern> patterns;
840 ArrayRef<PDLConstraintFunction> constraintFunctions;
841 ArrayRef<PDLCreateFunction> createFunctions;
842 ArrayRef<PDLRewriteFunction> rewriteFunctions;
843 };
844 } // end anonymous namespace
845
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)846 void ByteCodeExecutor::execute(
847 PatternRewriter &rewriter,
848 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
849 Optional<Location> mainRewriteLoc) {
850 while (true) {
851 OpCode opCode = static_cast<OpCode>(read());
852 switch (opCode) {
853 case ApplyConstraint: {
854 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
855 const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
856 ArrayAttr constParams = read<ArrayAttr>();
857 SmallVector<PDLValue, 16> args;
858 readList<PDLValue>(args);
859 LLVM_DEBUG({
860 llvm::dbgs() << " * Arguments: ";
861 llvm::interleaveComma(args, llvm::dbgs());
862 llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
863 });
864
865 // Invoke the constraint and jump to the proper destination.
866 selectJump(succeeded(constraintFn(args, constParams, rewriter)));
867 break;
868 }
869 case ApplyRewrite: {
870 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
871 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
872 ArrayAttr constParams = read<ArrayAttr>();
873 Operation *root = read<Operation *>();
874 SmallVector<PDLValue, 16> args;
875 readList<PDLValue>(args);
876
877 LLVM_DEBUG({
878 llvm::dbgs() << " * Root: " << *root << "\n"
879 << " * Arguments: ";
880 llvm::interleaveComma(args, llvm::dbgs());
881 llvm::dbgs() << "\n * Parameters: " << constParams << "\n\n";
882 });
883 rewriteFn(root, args, constParams, rewriter);
884 break;
885 }
886 case AreEqual: {
887 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
888 const void *lhs = read<const void *>();
889 const void *rhs = read<const void *>();
890
891 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
892 selectJump(lhs == rhs);
893 break;
894 }
895 case Branch: {
896 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
897 curCodeIt = &code[read<ByteCodeAddr>()];
898 break;
899 }
900 case CheckOperandCount: {
901 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
902 Operation *op = read<Operation *>();
903 uint32_t expectedCount = read<uint32_t>();
904
905 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
906 << " * Expected: " << expectedCount << "\n\n");
907 selectJump(op->getNumOperands() == expectedCount);
908 break;
909 }
910 case CheckOperationName: {
911 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
912 Operation *op = read<Operation *>();
913 OperationName expectedName = read<OperationName>();
914
915 LLVM_DEBUG(llvm::dbgs()
916 << " * Found: \"" << op->getName() << "\"\n"
917 << " * Expected: \"" << expectedName << "\"\n\n");
918 selectJump(op->getName() == expectedName);
919 break;
920 }
921 case CheckResultCount: {
922 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
923 Operation *op = read<Operation *>();
924 uint32_t expectedCount = read<uint32_t>();
925
926 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
927 << " * Expected: " << expectedCount << "\n\n");
928 selectJump(op->getNumResults() == expectedCount);
929 break;
930 }
931 case CreateNative: {
932 LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
933 const PDLCreateFunction &createFn = createFunctions[read()];
934 ByteCodeField resultIndex = read();
935 ArrayAttr constParams = read<ArrayAttr>();
936 SmallVector<PDLValue, 16> args;
937 readList<PDLValue>(args);
938
939 LLVM_DEBUG({
940 llvm::dbgs() << " * Arguments: ";
941 llvm::interleaveComma(args, llvm::dbgs());
942 llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
943 });
944
945 PDLValue result = createFn(args, constParams, rewriter);
946 memory[resultIndex] = result.getAsOpaquePointer();
947
948 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n\n");
949 break;
950 }
951 case CreateOperation: {
952 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
953 assert(mainRewriteLoc && "expected rewrite loc to be provided when "
954 "executing the rewriter bytecode");
955
956 unsigned memIndex = read();
957 OperationState state(*mainRewriteLoc, read<OperationName>());
958 readList<Value>(state.operands);
959 for (unsigned i = 0, e = read(); i != e; ++i) {
960 Identifier name = read<Identifier>();
961 if (Attribute attr = read<Attribute>())
962 state.addAttribute(name, attr);
963 }
964
965 bool hasInferredTypes = false;
966 for (unsigned i = 0, e = read(); i != e; ++i) {
967 Type resultType = read<Type>();
968 hasInferredTypes |= !resultType;
969 state.types.push_back(resultType);
970 }
971
972 // Handle the case where the operation has inferred types.
973 if (hasInferredTypes) {
974 InferTypeOpInterface::Concept *concept =
975 state.name.getAbstractOperation()
976 ->getInterface<InferTypeOpInterface>();
977
978 // TODO: Handle failure.
979 SmallVector<Type, 2> inferredTypes;
980 if (failed(concept->inferReturnTypes(
981 state.getContext(), state.location, state.operands,
982 state.attributes.getDictionary(state.getContext()),
983 state.regions, inferredTypes)))
984 return;
985
986 for (unsigned i = 0, e = state.types.size(); i != e; ++i)
987 if (!state.types[i])
988 state.types[i] = inferredTypes[i];
989 }
990 Operation *resultOp = rewriter.createOperation(state);
991 memory[memIndex] = resultOp;
992
993 LLVM_DEBUG({
994 llvm::dbgs() << " * Attributes: "
995 << state.attributes.getDictionary(state.getContext())
996 << "\n * Operands: ";
997 llvm::interleaveComma(state.operands, llvm::dbgs());
998 llvm::dbgs() << "\n * Result Types: ";
999 llvm::interleaveComma(state.types, llvm::dbgs());
1000 llvm::dbgs() << "\n * Result: " << *resultOp << "\n\n";
1001 });
1002 break;
1003 }
1004 case EraseOp: {
1005 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1006 Operation *op = read<Operation *>();
1007
1008 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n\n");
1009 rewriter.eraseOp(op);
1010 break;
1011 }
1012 case Finalize: {
1013 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1014 return;
1015 }
1016 case GetAttribute: {
1017 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1018 unsigned memIndex = read();
1019 Operation *op = read<Operation *>();
1020 Identifier attrName = read<Identifier>();
1021 Attribute attr = op->getAttr(attrName);
1022
1023 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1024 << " * Attribute: " << attrName << "\n"
1025 << " * Result: " << attr << "\n\n");
1026 memory[memIndex] = attr.getAsOpaquePointer();
1027 break;
1028 }
1029 case GetAttributeType: {
1030 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1031 unsigned memIndex = read();
1032 Attribute attr = read<Attribute>();
1033
1034 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1035 << " * Result: " << attr.getType() << "\n\n");
1036 memory[memIndex] = attr.getType().getAsOpaquePointer();
1037 break;
1038 }
1039 case GetDefiningOp: {
1040 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1041 unsigned memIndex = read();
1042 Value value = read<Value>();
1043 Operation *op = value ? value.getDefiningOp() : nullptr;
1044
1045 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1046 << " * Result: " << *op << "\n\n");
1047 memory[memIndex] = op;
1048 break;
1049 }
1050 case GetOperand0:
1051 case GetOperand1:
1052 case GetOperand2:
1053 case GetOperand3:
1054 case GetOperandN: {
1055 LLVM_DEBUG({
1056 llvm::dbgs() << "Executing GetOperand"
1057 << (opCode == GetOperandN ? Twine("N")
1058 : Twine(opCode - GetOperand0))
1059 << ":\n";
1060 });
1061 unsigned index =
1062 opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
1063 Operation *op = read<Operation *>();
1064 unsigned memIndex = read();
1065 Value operand =
1066 index < op->getNumOperands() ? op->getOperand(index) : Value();
1067
1068 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1069 << " * Index: " << index << "\n"
1070 << " * Result: " << operand << "\n\n");
1071 memory[memIndex] = operand.getAsOpaquePointer();
1072 break;
1073 }
1074 case GetResult0:
1075 case GetResult1:
1076 case GetResult2:
1077 case GetResult3:
1078 case GetResultN: {
1079 LLVM_DEBUG({
1080 llvm::dbgs() << "Executing GetResult"
1081 << (opCode == GetResultN ? Twine("N")
1082 : Twine(opCode - GetResult0))
1083 << ":\n";
1084 });
1085 unsigned index =
1086 opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
1087 Operation *op = read<Operation *>();
1088 unsigned memIndex = read();
1089 OpResult result =
1090 index < op->getNumResults() ? op->getResult(index) : OpResult();
1091
1092 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1093 << " * Index: " << index << "\n"
1094 << " * Result: " << result << "\n\n");
1095 memory[memIndex] = result.getAsOpaquePointer();
1096 break;
1097 }
1098 case GetValueType: {
1099 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1100 unsigned memIndex = read();
1101 Value value = read<Value>();
1102
1103 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1104 << " * Result: " << value.getType() << "\n\n");
1105 memory[memIndex] = value.getType().getAsOpaquePointer();
1106 break;
1107 }
1108 case IsNotNull: {
1109 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1110 const void *value = read<const void *>();
1111
1112 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n\n");
1113 selectJump(value != nullptr);
1114 break;
1115 }
1116 case RecordMatch: {
1117 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1118 assert(matches &&
1119 "expected matches to be provided when executing the matcher");
1120 unsigned patternIndex = read();
1121 PatternBenefit benefit = currentPatternBenefits[patternIndex];
1122 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1123
1124 // If the benefit of the pattern is impossible, skip the processing of the
1125 // rest of the pattern.
1126 if (benefit.isImpossibleToMatch()) {
1127 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n\n");
1128 curCodeIt = dest;
1129 break;
1130 }
1131
1132 // Create a fused location containing the locations of each of the
1133 // operations used in the match. This will be used as the location for
1134 // created operations during the rewrite that don't already have an
1135 // explicit location set.
1136 unsigned numMatchLocs = read();
1137 SmallVector<Location, 4> matchLocs;
1138 matchLocs.reserve(numMatchLocs);
1139 for (unsigned i = 0; i != numMatchLocs; ++i)
1140 matchLocs.push_back(read<Operation *>()->getLoc());
1141 Location matchLoc = rewriter.getFusedLoc(matchLocs);
1142
1143 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1144 << " * Location: " << matchLoc << "\n\n");
1145 matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
1146 readList<const void *>(matches->back().values);
1147 curCodeIt = dest;
1148 break;
1149 }
1150 case ReplaceOp: {
1151 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1152 Operation *op = read<Operation *>();
1153 SmallVector<Value, 16> args;
1154 readList<Value>(args);
1155
1156 LLVM_DEBUG({
1157 llvm::dbgs() << " * Operation: " << *op << "\n"
1158 << " * Values: ";
1159 llvm::interleaveComma(args, llvm::dbgs());
1160 llvm::dbgs() << "\n\n";
1161 });
1162 rewriter.replaceOp(op, args);
1163 break;
1164 }
1165 case SwitchAttribute: {
1166 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1167 Attribute value = read<Attribute>();
1168 ArrayAttr cases = read<ArrayAttr>();
1169 handleSwitch(value, cases);
1170 break;
1171 }
1172 case SwitchOperandCount: {
1173 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1174 Operation *op = read<Operation *>();
1175 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1176
1177 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1178 handleSwitch(op->getNumOperands(), cases);
1179 break;
1180 }
1181 case SwitchOperationName: {
1182 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1183 OperationName value = read<Operation *>()->getName();
1184 size_t caseCount = read();
1185
1186 // The operation names are stored in-line, so to print them out for
1187 // debugging purposes we need to read the array before executing the
1188 // switch so that we can display all of the possible values.
1189 LLVM_DEBUG({
1190 const ByteCodeField *prevCodeIt = curCodeIt;
1191 llvm::dbgs() << " * Value: " << value << "\n"
1192 << " * Cases: ";
1193 llvm::interleaveComma(
1194 llvm::map_range(llvm::seq<size_t>(0, caseCount),
1195 [&](size_t i) { return read<OperationName>(); }),
1196 llvm::dbgs());
1197 llvm::dbgs() << "\n\n";
1198 curCodeIt = prevCodeIt;
1199 });
1200
1201 // Try to find the switch value within any of the cases.
1202 size_t jumpDest = 0;
1203 for (size_t i = 0; i != caseCount; ++i) {
1204 if (read<OperationName>() == value) {
1205 curCodeIt += (caseCount - i - 1);
1206 jumpDest = i + 1;
1207 break;
1208 }
1209 }
1210 selectJump(jumpDest);
1211 break;
1212 }
1213 case SwitchResultCount: {
1214 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1215 Operation *op = read<Operation *>();
1216 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1217
1218 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1219 handleSwitch(op->getNumResults(), cases);
1220 break;
1221 }
1222 case SwitchType: {
1223 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1224 Type value = read<Type>();
1225 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1226 handleSwitch(value, cases);
1227 break;
1228 }
1229 }
1230 }
1231 }
1232
1233 /// Run the pattern matcher on the given root operation, collecting the matched
1234 /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const1235 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1236 SmallVectorImpl<MatchResult> &matches,
1237 PDLByteCodeMutableState &state) const {
1238 // The first memory slot is always the root operation.
1239 state.memory[0] = op;
1240
1241 // The matcher function always starts at code address 0.
1242 ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1243 matcherByteCode, state.currentPatternBenefits,
1244 patterns, constraintFunctions, createFunctions,
1245 rewriteFunctions);
1246 executor.execute(rewriter, &matches);
1247
1248 // Order the found matches by benefit.
1249 std::stable_sort(matches.begin(), matches.end(),
1250 [](const MatchResult &lhs, const MatchResult &rhs) {
1251 return lhs.benefit > rhs.benefit;
1252 });
1253 }
1254
1255 /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const1256 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1257 PDLByteCodeMutableState &state) const {
1258 // The arguments of the rewrite function are stored at the start of the
1259 // memory buffer.
1260 llvm::copy(match.values, state.memory.begin());
1261
1262 ByteCodeExecutor executor(
1263 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1264 uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
1265 constraintFunctions, createFunctions, rewriteFunctions);
1266 executor.execute(rewriter, /*matches=*/nullptr, match.location);
1267 }
1268