1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26
27 #define DEBUG_TYPE "pdl-bytecode"
28
29 using namespace mlir;
30 using namespace mlir::detail;
31
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37 ByteCodeAddr rewriterAddr) {
38 SmallVector<StringRef, 8> generatedOps;
39 if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40 generatedOps =
41 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42
43 PatternBenefit benefit = matchOp.benefit();
44 MLIRContext *ctx = matchOp.getContext();
45
46 // Check to see if this is pattern matches a specific operation type.
47 if (Optional<StringRef> rootKind = matchOp.rootKind())
48 return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49 generatedOps);
50 return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51 generatedOps);
52 }
53
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62 PatternBenefit benefit) {
63 currentPatternBenefits[patternIndex] = benefit;
64 }
65
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
cleanupAfterMatchAndRewrite()69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70 allocatedTypeRangeMemory.clear();
71 allocatedValueRangeMemory.clear();
72 }
73
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77
namespace {
/// The opcodes of the PDL bytecode. Every instruction emitted by `Generator`
/// starts with one of these values, followed by the operand fields that the
/// corresponding `Generator::generate` overload appends.
///
/// The enumerator's numeric value (its position in this list) is what gets
/// written into the bytecode stream, so reordering or inserting entries changes
/// the encoding. Some `pdl_interp` ops have no opcode of their own; for
/// example, `check_attribute` and `check_type` are both emitted as `AreEqual`.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Create an operation.
  CreateOperation,
  /// Create a range of types.
  CreateTypes,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE(review): the numbered variants presumably encode small constant
  /// indices directly, with `GetOperandN` as the general form; confirm
  /// against the interpreter.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation.
  /// NOTE(review): the same numbered/`N` split as for operands; confirm
  /// against the interpreter.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // end anonymous namespace
152
153 //===----------------------------------------------------------------------===//
154 // ByteCode Generation
155 //===----------------------------------------------------------------------===//
156
157 //===----------------------------------------------------------------------===//
158 // Generator
159
160 namespace {
161 struct ByteCodeWriter;
162
163 /// This class represents the main generator for the pattern bytecode.
164 class Generator {
165 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,ByteCodeField & maxTypeRangeMemoryIndex,ByteCodeField & maxValueRangeMemoryIndex,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)166 Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
167 SmallVectorImpl<ByteCodeField> &matcherByteCode,
168 SmallVectorImpl<ByteCodeField> &rewriterByteCode,
169 SmallVectorImpl<PDLByteCodePattern> &patterns,
170 ByteCodeField &maxValueMemoryIndex,
171 ByteCodeField &maxTypeRangeMemoryIndex,
172 ByteCodeField &maxValueRangeMemoryIndex,
173 llvm::StringMap<PDLConstraintFunction> &constraintFns,
174 llvm::StringMap<PDLRewriteFunction> &rewriteFns)
175 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
176 rewriterByteCode(rewriterByteCode), patterns(patterns),
177 maxValueMemoryIndex(maxValueMemoryIndex),
178 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
179 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
180 for (auto it : llvm::enumerate(constraintFns))
181 constraintToMemIndex.try_emplace(it.value().first(), it.index());
182 for (auto it : llvm::enumerate(rewriteFns))
183 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
184 }
185
186 /// Generate the bytecode for the given PDL interpreter module.
187 void generate(ModuleOp module);
188
189 /// Return the memory index to use for the given value.
getMemIndex(Value value)190 ByteCodeField &getMemIndex(Value value) {
191 assert(valueToMemIndex.count(value) &&
192 "expected memory index to be assigned");
193 return valueToMemIndex[value];
194 }
195
196 /// Return the range memory index used to store the given range value.
getRangeStorageIndex(Value value)197 ByteCodeField &getRangeStorageIndex(Value value) {
198 assert(valueToRangeIndex.count(value) &&
199 "expected range index to be assigned");
200 return valueToRangeIndex[value];
201 }
202
203 /// Return an index to use when referring to the given data that is uniqued in
204 /// the MLIR context.
205 template <typename T>
206 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)207 getMemIndex(T val) {
208 const void *opaqueVal = val.getAsOpaquePointer();
209
210 // Get or insert a reference to this value.
211 auto it = uniquedDataToMemIndex.try_emplace(
212 opaqueVal, maxValueMemoryIndex + uniquedData.size());
213 if (it.second)
214 uniquedData.push_back(opaqueVal);
215 return it.first->second;
216 }
217
218 private:
219 /// Allocate memory indices for the results of operations within the matcher
220 /// and rewriters.
221 void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
222
223 /// Generate the bytecode for the given operation.
224 void generate(Operation *op, ByteCodeWriter &writer);
225 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
226 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
227 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
228 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
229 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
230 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
231 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
232 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
233 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
234 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
235 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
236 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
237 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
238 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
239 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
240 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
241 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
242 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
243 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
244 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
245 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
246 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
247 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
248 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
249 void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
250 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
251 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
252 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
253 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
254 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
255 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
256 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
257 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
258 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
259
260 /// Mapping from value to its corresponding memory index.
261 DenseMap<Value, ByteCodeField> valueToMemIndex;
262
263 /// Mapping from a range value to its corresponding range storage index.
264 DenseMap<Value, ByteCodeField> valueToRangeIndex;
265
266 /// Mapping from the name of an externally registered rewrite to its index in
267 /// the bytecode registry.
268 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
269
270 /// Mapping from the name of an externally registered constraint to its index
271 /// in the bytecode registry.
272 llvm::StringMap<ByteCodeField> constraintToMemIndex;
273
274 /// Mapping from rewriter function name to the bytecode address of the
275 /// rewriter function in byte.
276 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
277
278 /// Mapping from a uniqued storage object to its memory index within
279 /// `uniquedData`.
280 DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
281
282 /// The current MLIR context.
283 MLIRContext *ctx;
284
285 /// Data of the ByteCode class to be populated.
286 std::vector<const void *> &uniquedData;
287 SmallVectorImpl<ByteCodeField> &matcherByteCode;
288 SmallVectorImpl<ByteCodeField> &rewriterByteCode;
289 SmallVectorImpl<PDLByteCodePattern> &patterns;
290 ByteCodeField &maxValueMemoryIndex;
291 ByteCodeField &maxTypeRangeMemoryIndex;
292 ByteCodeField &maxValueRangeMemoryIndex;
293 };
294
295 /// This class provides utilities for writing a bytecode stream.
296 struct ByteCodeWriter {
ByteCodeWriter__anon06213e220211::ByteCodeWriter297 ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
298 : bytecode(bytecode), generator(generator) {}
299
300 /// Append a field to the bytecode.
append__anon06213e220211::ByteCodeWriter301 void append(ByteCodeField field) { bytecode.push_back(field); }
append__anon06213e220211::ByteCodeWriter302 void append(OpCode opCode) { bytecode.push_back(opCode); }
303
304 /// Append an address to the bytecode.
append__anon06213e220211::ByteCodeWriter305 void append(ByteCodeAddr field) {
306 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
307 "unexpected ByteCode address size");
308
309 ByteCodeField fieldParts[2];
310 std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
311 bytecode.append({fieldParts[0], fieldParts[1]});
312 }
313
314 /// Append a successor range to the bytecode, the exact address will need to
315 /// be resolved later.
append__anon06213e220211::ByteCodeWriter316 void append(SuccessorRange successors) {
317 // Add back references to the any successors so that the address can be
318 // resolved later.
319 for (Block *successor : successors) {
320 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
321 append(ByteCodeAddr(0));
322 }
323 }
324
325 /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anon06213e220211::ByteCodeWriter326 void appendPDLValueList(OperandRange values) {
327 bytecode.push_back(values.size());
328 for (Value value : values)
329 appendPDLValue(value);
330 }
331
332 /// Append a value as a PDLValue.
appendPDLValue__anon06213e220211::ByteCodeWriter333 void appendPDLValue(Value value) {
334 appendPDLValueKind(value);
335 append(value);
336 }
337
338 /// Append the PDLValue::Kind of the given value.
appendPDLValueKind__anon06213e220211::ByteCodeWriter339 void appendPDLValueKind(Value value) {
340 // Append the type of the value in addition to the value itself.
341 PDLValue::Kind kind =
342 TypeSwitch<Type, PDLValue::Kind>(value.getType())
343 .Case<pdl::AttributeType>(
344 [](Type) { return PDLValue::Kind::Attribute; })
345 .Case<pdl::OperationType>(
346 [](Type) { return PDLValue::Kind::Operation; })
347 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
348 if (rangeTy.getElementType().isa<pdl::TypeType>())
349 return PDLValue::Kind::TypeRange;
350 return PDLValue::Kind::ValueRange;
351 })
352 .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
353 .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
354 bytecode.push_back(static_cast<ByteCodeField>(kind));
355 }
356
357 /// Check if the given class `T` has an iterator type.
358 template <typename T, typename... Args>
359 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
360
361 /// Append a value that will be stored in a memory slot and not inline within
362 /// the bytecode.
363 template <typename T>
364 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
365 std::is_pointer<T>::value>
append__anon06213e220211::ByteCodeWriter366 append(T value) {
367 bytecode.push_back(generator.getMemIndex(value));
368 }
369
370 /// Append a range of values.
371 template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
372 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anon06213e220211::ByteCodeWriter373 append(T range) {
374 bytecode.push_back(llvm::size(range));
375 for (auto it : range)
376 append(it);
377 }
378
379 /// Append a variadic number of fields to the bytecode.
380 template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anon06213e220211::ByteCodeWriter381 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
382 append(field);
383 append(field2, fields...);
384 }
385
386 /// Successor references in the bytecode that have yet to be resolved.
387 DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
388
389 /// The underlying bytecode buffer.
390 SmallVectorImpl<ByteCodeField> &bytecode;
391
392 /// The main generator producing PDL.
393 Generator &generator;
394 };
395
396 /// This class represents a live range of PDL Interpreter values, containing
397 /// information about when values are live within a match/rewrite.
398 struct ByteCodeLiveRange {
399 using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
400 using Allocator = Set::Allocator;
401
ByteCodeLiveRange__anon06213e220211::ByteCodeLiveRange402 ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
403
404 /// Union this live range with the one provided.
unionWith__anon06213e220211::ByteCodeLiveRange405 void unionWith(const ByteCodeLiveRange &rhs) {
406 for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
407 liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
408 }
409
410 /// Returns true if this range overlaps with the one provided.
overlaps__anon06213e220211::ByteCodeLiveRange411 bool overlaps(const ByteCodeLiveRange &rhs) const {
412 return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
413 }
414
415 /// A map representing the ranges of the match/rewrite that a value is live in
416 /// the interpreter.
417 llvm::IntervalMap<ByteCodeField, char, 16> liveness;
418
419 /// The type range storage index for this range.
420 Optional<unsigned> typeRangeIndex;
421
422 /// The value range storage index for this range.
423 Optional<unsigned> valueRangeIndex;
424 };
425 } // end anonymous namespace
426
generate(ModuleOp module)427 void Generator::generate(ModuleOp module) {
428 FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
429 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
430 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
431 pdl_interp::PDLInterpDialect::getRewriterModuleName());
432 assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
433
434 // Allocate memory indices for the results of operations within the matcher
435 // and rewriters.
436 allocateMemoryIndices(matcherFunc, rewriterModule);
437
438 // Generate code for the rewriter functions.
439 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
440 for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
441 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
442 for (Operation &op : rewriterFunc.getOps())
443 generate(&op, rewriterByteCodeWriter);
444 }
445 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
446 "unexpected branches in rewriter function");
447
448 // Generate code for the matcher function.
449 DenseMap<Block *, ByteCodeAddr> blockToAddr;
450 llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
451 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
452 for (Block *block : rpot) {
453 // Keep track of where this block begins within the matcher function.
454 blockToAddr.try_emplace(block, matcherByteCode.size());
455 for (Operation &op : *block)
456 generate(&op, matcherByteCodeWriter);
457 }
458
459 // Resolve successor references in the matcher.
460 for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
461 ByteCodeAddr addr = blockToAddr[it.first];
462 for (unsigned offsetToFix : it.second)
463 std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
464 }
465 }
466
/// Assign memory slots to the SSA values of the rewriter functions and of the
/// matcher function.
///
/// Every value gets an entry in `valueToMemIndex`. Values of `pdl.range` type
/// also get an entry in `valueToRangeIndex`, numbered separately for ranges of
/// types and ranges of values. On return, `maxValueMemoryIndex`,
/// `maxTypeRangeMemoryIndex` and `maxValueRangeMemoryIndex` are at least the
/// number of slots of each kind used by any single rewriter or by the matcher.
void Generator::allocateMemoryIndices(FuncOp matcherFunc,
                                      ModuleOp rewriterModule) {
  // Rewriters use a simplistic allocation scheme that assigns an index to each
  // value. Numbering restarts from zero for every rewriter function, so only
  // the per-function maxima are recorded.
  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
    ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
    // Give `val` the next memory slot and, if it is a range of types or a
    // range of values, the next slot of that range storage.
    auto processRewriterValue = [&](Value val) {
      valueToMemIndex.try_emplace(val, index++);
      if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
        Type elementTy = rangeType.getElementType();
        if (elementTy.isa<pdl::TypeType>())
          valueToRangeIndex.try_emplace(val, typeRangeIndex++);
        else if (elementTy.isa<pdl::ValueType>())
          valueToRangeIndex.try_emplace(val, valueRangeIndex++);
      }
    };

    // Function arguments first, then every result defined in the body.
    for (BlockArgument arg : rewriterFunc.getArguments())
      processRewriterValue(arg);
    rewriterFunc.getBody().walk([&](Operation *op) {
      for (Value result : op->getResults())
        processRewriterValue(result);
    });
    if (index > maxValueMemoryIndex)
      maxValueMemoryIndex = index;
    if (typeRangeIndex > maxTypeRangeMemoryIndex)
      maxTypeRangeMemoryIndex = typeRangeIndex;
    if (valueRangeIndex > maxValueRangeMemoryIndex)
      maxValueRangeMemoryIndex = valueRangeIndex;
  }

  // The matcher function uses a more sophisticated numbering that tries to
  // minimize the number of memory indices assigned. This is done by determining
  // a live range of the values within the matcher, then the allocation is just
  // finding the minimal number of overlapping live ranges. This is essentially
  // a simplified form of register allocation where we don't necessarily have a
  // limited number of registers, but we still want to minimize the number used.
  //
  // Number every matcher operation in walk order; these numbers are the
  // coordinates of the live intervals computed below.
  // NOTE(review): the numbers are stored as `ByteCodeField`, whose width is
  // declared outside this file. If it is 16 bits, a matcher with more than
  // 65535 operations would wrap here; confirm.
  DenseMap<Operation *, ByteCodeField> opToIndex;
  matcherFunc.getBody().walk([&](Operation *op) {
    opToIndex.insert(std::make_pair(op, opToIndex.size()));
  });

  // Liveness info for each of the defs within the matcher.
  // `ByteCodeLiveRange` is stored by value here and in `allocatedIndices`
  // below. Both containers relocate their elements as they grow, so this code
  // relies on `ByteCodeLiveRange` being safely copyable.
  ByteCodeLiveRange::Allocator allocator;
  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;

  // Assign the root operation being matched to slot 0.
  BlockArgument rootOpArg = matcherFunc.getArgument(0);
  valueToMemIndex[rootOpArg] = 0;

  // Walk each of the blocks, computing the def interval that the value is used.
  Liveness matcherLiveness(matcherFunc);
  for (Block &block : matcherFunc.getBody()) {
    const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
    assert(info && "expected liveness info for block");
    auto processValue = [&](Value value, Operation *firstUseOrDef) {
      // We don't need to process the root op argument, this value is always
      // assigned to the first memory slot.
      if (value == rootOpArg)
        return;

      // Set indices for the range of this block that the value is used. The
      // interval runs from `firstUseOrDef` to the end operation that
      // `Liveness` reports for this value within this block.
      auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
      defRangeIt->second.liveness.insert(
          opToIndex[firstUseOrDef],
          opToIndex[info->getEndOperation(value, firstUseOrDef)],
          /*dummyValue*/ 0);

      // Check to see if this value is a range type. Setting the optional to 0
      // only marks that a type/value range slot is needed; the real range
      // index is chosen during allocation below.
      if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
        Type eleType = rangeTy.getElementType();
        if (eleType.isa<pdl::TypeType>())
          defRangeIt->second.typeRangeIndex = 0;
        else if (eleType.isa<pdl::ValueType>())
          defRangeIt->second.valueRangeIndex = 0;
      }
    };

    // Process the live-ins of this block.
    for (Value liveIn : info->in())
      processValue(liveIn, &block.front());

    // Process any new defs within this block.
    for (Operation &op : block)
      for (Value result : op.getResults())
        processValue(result, &op);
  }

  // Greedily allocate memory slots using the computed def live ranges.
  // `allocatedIndices[i]` holds the union of the live ranges already placed in
  // memory slot `i + 1`; slot 0 is reserved for the root operation, which is
  // why `numIndices` starts at 1. Values are visited in `DenseMap` iteration
  // order, which is hash-based rather than program order.
  std::vector<ByteCodeLiveRange> allocatedIndices;
  ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
  for (auto &defIt : valueDefRanges) {
    // Newly inserted entries start at 0, which doubles as "not yet assigned"
    // because slot 0 is reserved for the root operation.
    ByteCodeField &memIndex = valueToMemIndex[defIt.first];
    ByteCodeLiveRange &defRange = defIt.second;

    // Try to allocate to an existing index.
    for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
      ByteCodeLiveRange &existingRange = existingIndexIt.value();
      if (!defRange.overlaps(existingRange)) {
        existingRange.unionWith(defRange);
        memIndex = existingIndexIt.index() + 1;

        // Share the slot's range storage index of the matching kind. One is
        // created the first time this slot holds a range of that kind.
        if (defRange.typeRangeIndex) {
          if (!existingRange.typeRangeIndex)
            existingRange.typeRangeIndex = numTypeRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
        } else if (defRange.valueRangeIndex) {
          if (!existingRange.valueRangeIndex)
            existingRange.valueRangeIndex = numValueRanges++;
          valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
        }
        break;
      }
    }

    // If no existing index could be used, add a new one.
    if (memIndex == 0) {
      allocatedIndices.emplace_back(allocator);
      ByteCodeLiveRange &newRange = allocatedIndices.back();
      newRange.unionWith(defRange);

      // Allocate an index for type/value ranges.
      if (defRange.typeRangeIndex) {
        newRange.typeRangeIndex = numTypeRanges;
        valueToRangeIndex[defIt.first] = numTypeRanges++;
      } else if (defRange.valueRangeIndex) {
        newRange.valueRangeIndex = numValueRanges;
        valueToRangeIndex[defIt.first] = numValueRanges++;
      }

      // The new slot is `allocatedIndices.size()` (i.e. its position + 1).
      memIndex = allocatedIndices.size();
      ++numIndices;
    }
  }

  // Update the max number of indices.
  if (numIndices > maxValueMemoryIndex)
    maxValueMemoryIndex = numIndices;
  if (numTypeRanges > maxTypeRangeMemoryIndex)
    maxTypeRangeMemoryIndex = numTypeRanges;
  if (numValueRanges > maxValueRangeMemoryIndex)
    maxValueRangeMemoryIndex = numValueRanges;
}
610
generate(Operation * op,ByteCodeWriter & writer)611 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
612 TypeSwitch<Operation *>(op)
613 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
614 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
615 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
616 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
617 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
618 pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
619 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
620 pdl_interp::EraseOp, pdl_interp::FinalizeOp,
621 pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
622 pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
623 pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
624 pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
625 pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
626 pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
627 pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
628 pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
629 pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
630 [&](auto interpOp) { this->generate(interpOp, writer); })
631 .Default([](Operation *) {
632 llvm_unreachable("unknown `pdl_interp` operation");
633 });
634 }
635
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)636 void Generator::generate(pdl_interp::ApplyConstraintOp op,
637 ByteCodeWriter &writer) {
638 assert(constraintToMemIndex.count(op.name()) &&
639 "expected index for constraint function");
640 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
641 op.constParamsAttr());
642 writer.appendPDLValueList(op.args());
643 writer.append(op.getSuccessors());
644 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)645 void Generator::generate(pdl_interp::ApplyRewriteOp op,
646 ByteCodeWriter &writer) {
647 assert(externalRewriterToMemIndex.count(op.name()) &&
648 "expected index for rewrite function");
649 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
650 op.constParamsAttr());
651 writer.appendPDLValueList(op.args());
652
653 ResultRange results = op.results();
654 writer.append(ByteCodeField(results.size()));
655 for (Value result : results) {
656 // In debug mode we also record the expected kind of the result, so that we
657 // can provide extra verification of the native rewrite function.
658 #ifndef NDEBUG
659 writer.appendPDLValueKind(result);
660 #endif
661
662 // Range results also need to append the range storage index.
663 if (result.getType().isa<pdl::RangeType>())
664 writer.append(getRangeStorageIndex(result));
665 writer.append(result);
666 }
667 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)668 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
669 Value lhs = op.lhs();
670 if (lhs.getType().isa<pdl::RangeType>()) {
671 writer.append(OpCode::AreRangesEqual);
672 writer.appendPDLValueKind(lhs);
673 writer.append(op.lhs(), op.rhs(), op.getSuccessors());
674 return;
675 }
676
677 writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
678 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)679 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
680 writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
681 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)682 void Generator::generate(pdl_interp::CheckAttributeOp op,
683 ByteCodeWriter &writer) {
684 writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
685 op.getSuccessors());
686 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)687 void Generator::generate(pdl_interp::CheckOperandCountOp op,
688 ByteCodeWriter &writer) {
689 writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
690 static_cast<ByteCodeField>(op.compareAtLeast()),
691 op.getSuccessors());
692 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)693 void Generator::generate(pdl_interp::CheckOperationNameOp op,
694 ByteCodeWriter &writer) {
695 writer.append(OpCode::CheckOperationName, op.operation(),
696 OperationName(op.name(), ctx), op.getSuccessors());
697 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)698 void Generator::generate(pdl_interp::CheckResultCountOp op,
699 ByteCodeWriter &writer) {
700 writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
701 static_cast<ByteCodeField>(op.compareAtLeast()),
702 op.getSuccessors());
703 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)704 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
705 writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
706 }
generate(pdl_interp::CheckTypesOp op,ByteCodeWriter & writer)707 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
708 writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
709 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)710 void Generator::generate(pdl_interp::CreateAttributeOp op,
711 ByteCodeWriter &writer) {
712 // Simply repoint the memory index of the result to the constant.
713 getMemIndex(op.attribute()) = getMemIndex(op.value());
714 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)715 void Generator::generate(pdl_interp::CreateOperationOp op,
716 ByteCodeWriter &writer) {
717 writer.append(OpCode::CreateOperation, op.operation(),
718 OperationName(op.name(), ctx));
719 writer.appendPDLValueList(op.operands());
720
721 // Add the attributes.
722 OperandRange attributes = op.attributes();
723 writer.append(static_cast<ByteCodeField>(attributes.size()));
724 for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
725 writer.append(
726 Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
727 std::get<1>(it));
728 }
729 writer.appendPDLValueList(op.types());
730 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)731 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
732 // Simply repoint the memory index of the result to the constant.
733 getMemIndex(op.result()) = getMemIndex(op.value());
734 }
generate(pdl_interp::CreateTypesOp op,ByteCodeWriter & writer)735 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
736 writer.append(OpCode::CreateTypes, op.result(),
737 getRangeStorageIndex(op.result()), op.value());
738 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)739 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
740 writer.append(OpCode::EraseOp, op.operation());
741 }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)742 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
743 writer.append(OpCode::Finalize);
744 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)745 void Generator::generate(pdl_interp::GetAttributeOp op,
746 ByteCodeWriter &writer) {
747 writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
748 Identifier::get(op.name(), ctx));
749 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)750 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
751 ByteCodeWriter &writer) {
752 writer.append(OpCode::GetAttributeType, op.result(), op.value());
753 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)754 void Generator::generate(pdl_interp::GetDefiningOpOp op,
755 ByteCodeWriter &writer) {
756 writer.append(OpCode::GetDefiningOp, op.operation());
757 writer.appendPDLValue(op.value());
758 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)759 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
760 uint32_t index = op.index();
761 if (index < 4)
762 writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
763 else
764 writer.append(OpCode::GetOperandN, index);
765 writer.append(op.operation(), op.value());
766 }
generate(pdl_interp::GetOperandsOp op,ByteCodeWriter & writer)767 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
768 Value result = op.value();
769 Optional<uint32_t> index = op.index();
770 writer.append(OpCode::GetOperands,
771 index.getValueOr(std::numeric_limits<uint32_t>::max()),
772 op.operation());
773 if (result.getType().isa<pdl::RangeType>())
774 writer.append(getRangeStorageIndex(result));
775 else
776 writer.append(std::numeric_limits<ByteCodeField>::max());
777 writer.append(result);
778 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)779 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
780 uint32_t index = op.index();
781 if (index < 4)
782 writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
783 else
784 writer.append(OpCode::GetResultN, index);
785 writer.append(op.operation(), op.value());
786 }
generate(pdl_interp::GetResultsOp op,ByteCodeWriter & writer)787 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
788 Value result = op.value();
789 Optional<uint32_t> index = op.index();
790 writer.append(OpCode::GetResults,
791 index.getValueOr(std::numeric_limits<uint32_t>::max()),
792 op.operation());
793 if (result.getType().isa<pdl::RangeType>())
794 writer.append(getRangeStorageIndex(result));
795 else
796 writer.append(std::numeric_limits<ByteCodeField>::max());
797 writer.append(result);
798 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)799 void Generator::generate(pdl_interp::GetValueTypeOp op,
800 ByteCodeWriter &writer) {
801 if (op.getType().isa<pdl::RangeType>()) {
802 Value result = op.result();
803 writer.append(OpCode::GetValueRangeTypes, result,
804 getRangeStorageIndex(result), op.value());
805 } else {
806 writer.append(OpCode::GetValueType, op.result(), op.value());
807 }
808 }
809
generate(pdl_interp::InferredTypesOp op,ByteCodeWriter & writer)810 void Generator::generate(pdl_interp::InferredTypesOp op,
811 ByteCodeWriter &writer) {
812 // InferType maps to a null type as a marker for inferring result types.
813 getMemIndex(op.type()) = getMemIndex(Type());
814 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)815 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
816 writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
817 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)818 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
819 ByteCodeField patternIndex = patterns.size();
820 patterns.emplace_back(PDLByteCodePattern::create(
821 op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
822 writer.append(OpCode::RecordMatch, patternIndex,
823 SuccessorRange(op.getOperation()), op.matchedOps());
824 writer.appendPDLValueList(op.inputs());
825 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)826 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
827 writer.append(OpCode::ReplaceOp, op.operation());
828 writer.appendPDLValueList(op.replValues());
829 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)830 void Generator::generate(pdl_interp::SwitchAttributeOp op,
831 ByteCodeWriter &writer) {
832 writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
833 op.getSuccessors());
834 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)835 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
836 ByteCodeWriter &writer) {
837 writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
838 op.getSuccessors());
839 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)840 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
841 ByteCodeWriter &writer) {
842 auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
843 return OperationName(attr.cast<StringAttr>().getValue(), ctx);
844 });
845 writer.append(OpCode::SwitchOperationName, op.operation(), cases,
846 op.getSuccessors());
847 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)848 void Generator::generate(pdl_interp::SwitchResultCountOp op,
849 ByteCodeWriter &writer) {
850 writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
851 op.getSuccessors());
852 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)853 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
855 op.getSuccessors());
856 }
generate(pdl_interp::SwitchTypesOp op,ByteCodeWriter & writer)857 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
858 writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
859 op.getSuccessors());
860 }
861
862 //===----------------------------------------------------------------------===//
863 // PDLByteCode
864 //===----------------------------------------------------------------------===//
865
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)866 PDLByteCode::PDLByteCode(ModuleOp module,
867 llvm::StringMap<PDLConstraintFunction> constraintFns,
868 llvm::StringMap<PDLRewriteFunction> rewriteFns) {
869 Generator generator(module.getContext(), uniquedData, matcherByteCode,
870 rewriterByteCode, patterns, maxValueMemoryIndex,
871 maxTypeRangeCount, maxValueRangeCount, constraintFns,
872 rewriteFns);
873 generator.generate(module);
874
875 // Initialize the external functions.
876 for (auto &it : constraintFns)
877 constraintFunctions.push_back(std::move(it.second));
878 for (auto &it : rewriteFns)
879 rewriteFunctions.push_back(std::move(it.second));
880 }
881
882 /// Initialize the given state such that it can be used to execute the current
883 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const884 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
885 state.memory.resize(maxValueMemoryIndex, nullptr);
886 state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
887 state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
888 state.currentPatternBenefits.reserve(patterns.size());
889 for (const PDLByteCodePattern &pattern : patterns)
890 state.currentPatternBenefits.push_back(pattern.getBenefit());
891 }
892
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
895
896 namespace {
897 /// This class provides support for executing a bytecode stream.
898 class ByteCodeExecutor {
899 public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,MutableArrayRef<TypeRange> typeRangeMemory,std::vector<llvm::OwningArrayRef<Type>> & allocatedTypeRangeMemory,MutableArrayRef<ValueRange> valueRangeMemory,std::vector<llvm::OwningArrayRef<Value>> & allocatedValueRangeMemory,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)900 ByteCodeExecutor(
901 const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
902 MutableArrayRef<TypeRange> typeRangeMemory,
903 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
904 MutableArrayRef<ValueRange> valueRangeMemory,
905 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
906 ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
907 ArrayRef<PatternBenefit> currentPatternBenefits,
908 ArrayRef<PDLByteCodePattern> patterns,
909 ArrayRef<PDLConstraintFunction> constraintFunctions,
910 ArrayRef<PDLRewriteFunction> rewriteFunctions)
911 : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
912 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
913 valueRangeMemory(valueRangeMemory),
914 allocatedValueRangeMemory(allocatedValueRangeMemory),
915 uniquedMemory(uniquedMemory), code(code),
916 currentPatternBenefits(currentPatternBenefits), patterns(patterns),
917 constraintFunctions(constraintFunctions),
918 rewriteFunctions(rewriteFunctions) {}
919
920 /// Start executing the code at the current bytecode index. `matches` is an
921 /// optional field provided when this function is executed in a matching
922 /// context.
923 void execute(PatternRewriter &rewriter,
924 SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
925 Optional<Location> mainRewriteLoc = {});
926
927 private:
928 /// Internal implementation of executing each of the bytecode commands.
929 void executeApplyConstraint(PatternRewriter &rewriter);
930 void executeApplyRewrite(PatternRewriter &rewriter);
931 void executeAreEqual();
932 void executeAreRangesEqual();
933 void executeBranch();
934 void executeCheckOperandCount();
935 void executeCheckOperationName();
936 void executeCheckResultCount();
937 void executeCheckTypes();
938 void executeCreateOperation(PatternRewriter &rewriter,
939 Location mainRewriteLoc);
940 void executeCreateTypes();
941 void executeEraseOp(PatternRewriter &rewriter);
942 void executeGetAttribute();
943 void executeGetAttributeType();
944 void executeGetDefiningOp();
945 void executeGetOperand(unsigned index);
946 void executeGetOperands();
947 void executeGetResult(unsigned index);
948 void executeGetResults();
949 void executeGetValueType();
950 void executeGetValueRangeTypes();
951 void executeIsNotNull();
952 void executeRecordMatch(PatternRewriter &rewriter,
953 SmallVectorImpl<PDLByteCode::MatchResult> &matches);
954 void executeReplaceOp(PatternRewriter &rewriter);
955 void executeSwitchAttribute();
956 void executeSwitchOperandCount();
957 void executeSwitchOperationName();
958 void executeSwitchResultCount();
959 void executeSwitchType();
960 void executeSwitchTypes();
961
962 /// Read a value from the bytecode buffer, optionally skipping a certain
963 /// number of prefix values. These methods always update the buffer to point
964 /// to the next field after the read data.
965 template <typename T = ByteCodeField>
read(size_t skipN=0)966 T read(size_t skipN = 0) {
967 curCodeIt += skipN;
968 return readImpl<T>();
969 }
read(size_t skipN=0)970 ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
971
972 /// Read a list of values from the bytecode buffer.
973 template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)974 void readList(SmallVectorImpl<T> &list) {
975 list.clear();
976 for (unsigned i = 0, e = read(); i != e; ++i)
977 list.push_back(read<ValueT>());
978 }
979
980 /// Read a list of values from the bytecode buffer. The values may be encoded
981 /// as either Value or ValueRange elements.
readValueList(SmallVectorImpl<Value> & list)982 void readValueList(SmallVectorImpl<Value> &list) {
983 for (unsigned i = 0, e = read(); i != e; ++i) {
984 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
985 list.push_back(read<Value>());
986 } else {
987 ValueRange *values = read<ValueRange *>();
988 list.append(values->begin(), values->end());
989 }
990 }
991 }
992
993 /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)994 void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
995 /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)996 void selectJump(size_t destIndex) {
997 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
998 }
999
1000 /// Handle a switch operation with the provided value and cases.
1001 template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
handleSwitch(const T & value,RangeT && cases,Comparator cmp={})1002 void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1003 LLVM_DEBUG({
1004 llvm::dbgs() << " * Value: " << value << "\n"
1005 << " * Cases: ";
1006 llvm::interleaveComma(cases, llvm::dbgs());
1007 llvm::dbgs() << "\n";
1008 });
1009
1010 // Check to see if the attribute value is within the case list. Jump to
1011 // the correct successor index based on the result.
1012 for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1013 if (cmp(*it, value))
1014 return selectJump(size_t((it - cases.begin()) + 1));
1015 selectJump(size_t(0));
1016 }
1017
1018 /// Internal implementation of reading various data types from the bytecode
1019 /// stream.
1020 template <typename T>
readFromMemory()1021 const void *readFromMemory() {
1022 size_t index = *curCodeIt++;
1023
1024 // If this type is an SSA value, it can only be stored in non-const memory.
1025 if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1026 Value>::value ||
1027 index < memory.size())
1028 return memory[index];
1029
1030 // Otherwise, if this index is not inbounds it is uniqued.
1031 return uniquedMemory[index - memory.size()];
1032 }
1033 template <typename T>
readImpl()1034 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1035 return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1036 }
1037 template <typename T>
1038 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1039 T>
readImpl()1040 readImpl() {
1041 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1042 }
1043 template <typename T>
readImpl()1044 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1045 switch (read<PDLValue::Kind>()) {
1046 case PDLValue::Kind::Attribute:
1047 return read<Attribute>();
1048 case PDLValue::Kind::Operation:
1049 return read<Operation *>();
1050 case PDLValue::Kind::Type:
1051 return read<Type>();
1052 case PDLValue::Kind::Value:
1053 return read<Value>();
1054 case PDLValue::Kind::TypeRange:
1055 return read<TypeRange *>();
1056 case PDLValue::Kind::ValueRange:
1057 return read<ValueRange *>();
1058 }
1059 llvm_unreachable("unhandled PDLValue::Kind");
1060 }
1061 template <typename T>
readImpl()1062 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1063 static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1064 "unexpected ByteCode address size");
1065 ByteCodeAddr result;
1066 std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1067 curCodeIt += 2;
1068 return result;
1069 }
1070 template <typename T>
readImpl()1071 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1072 return *curCodeIt++;
1073 }
1074 template <typename T>
readImpl()1075 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1076 return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1077 }
1078
1079 /// The underlying bytecode buffer.
1080 const ByteCodeField *curCodeIt;
1081
1082 /// The current execution memory.
1083 MutableArrayRef<const void *> memory;
1084 MutableArrayRef<TypeRange> typeRangeMemory;
1085 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1086 MutableArrayRef<ValueRange> valueRangeMemory;
1087 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1088
1089 /// References to ByteCode data necessary for execution.
1090 ArrayRef<const void *> uniquedMemory;
1091 ArrayRef<ByteCodeField> code;
1092 ArrayRef<PatternBenefit> currentPatternBenefits;
1093 ArrayRef<PDLByteCodePattern> patterns;
1094 ArrayRef<PDLConstraintFunction> constraintFunctions;
1095 ArrayRef<PDLRewriteFunction> rewriteFunctions;
1096 };
1097
1098 /// This class is an instantiation of the PDLResultList that provides access to
1099 /// the returned results. This API is not on `PDLResultList` to avoid
1100 /// overexposing access to information specific solely to the ByteCode.
1101 class ByteCodeRewriteResultList : public PDLResultList {
1102 public:
ByteCodeRewriteResultList(unsigned maxNumResults)1103 ByteCodeRewriteResultList(unsigned maxNumResults)
1104 : PDLResultList(maxNumResults) {}
1105
1106 /// Return the list of PDL results.
getResults()1107 MutableArrayRef<PDLValue> getResults() { return results; }
1108
1109 /// Return the type ranges allocated by this list.
getAllocatedTypeRanges()1110 MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1111 return allocatedTypeRanges;
1112 }
1113
1114 /// Return the value ranges allocated by this list.
getAllocatedValueRanges()1115 MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1116 return allocatedValueRanges;
1117 }
1118 };
1119 } // end anonymous namespace
1120
executeApplyConstraint(PatternRewriter & rewriter)1121 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1122 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1123 const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1124 ArrayAttr constParams = read<ArrayAttr>();
1125 SmallVector<PDLValue, 16> args;
1126 readList<PDLValue>(args);
1127
1128 LLVM_DEBUG({
1129 llvm::dbgs() << " * Arguments: ";
1130 llvm::interleaveComma(args, llvm::dbgs());
1131 llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
1132 });
1133
1134 // Invoke the constraint and jump to the proper destination.
1135 selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1136 }
1137
executeApplyRewrite(PatternRewriter & rewriter)1138 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1139 LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1140 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1141 ArrayAttr constParams = read<ArrayAttr>();
1142 SmallVector<PDLValue, 16> args;
1143 readList<PDLValue>(args);
1144
1145 LLVM_DEBUG({
1146 llvm::dbgs() << " * Arguments: ";
1147 llvm::interleaveComma(args, llvm::dbgs());
1148 llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
1149 });
1150
1151 // Execute the rewrite function.
1152 ByteCodeField numResults = read();
1153 ByteCodeRewriteResultList results(numResults);
1154 rewriteFn(args, constParams, rewriter, results);
1155
1156 assert(results.getResults().size() == numResults &&
1157 "native PDL rewrite function returned unexpected number of results");
1158
1159 // Store the results in the bytecode memory.
1160 for (PDLValue &result : results.getResults()) {
1161 LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1162
1163 // In debug mode we also verify the expected kind of the result.
1164 #ifndef NDEBUG
1165 assert(result.getKind() == read<PDLValue::Kind>() &&
1166 "native PDL rewrite function returned an unexpected type of result");
1167 #endif
1168
1169 // If the result is a range, we need to copy it over to the bytecodes
1170 // range memory.
1171 if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1172 unsigned rangeIndex = read();
1173 typeRangeMemory[rangeIndex] = *typeRange;
1174 memory[read()] = &typeRangeMemory[rangeIndex];
1175 } else if (Optional<ValueRange> valueRange =
1176 result.dyn_cast<ValueRange>()) {
1177 unsigned rangeIndex = read();
1178 valueRangeMemory[rangeIndex] = *valueRange;
1179 memory[read()] = &valueRangeMemory[rangeIndex];
1180 } else {
1181 memory[read()] = result.getAsOpaquePointer();
1182 }
1183 }
1184
1185 // Copy over any underlying storage allocated for result ranges.
1186 for (auto &it : results.getAllocatedTypeRanges())
1187 allocatedTypeRangeMemory.push_back(std::move(it));
1188 for (auto &it : results.getAllocatedValueRanges())
1189 allocatedValueRangeMemory.push_back(std::move(it));
1190 }
1191
executeAreEqual()1192 void ByteCodeExecutor::executeAreEqual() {
1193 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1194 const void *lhs = read<const void *>();
1195 const void *rhs = read<const void *>();
1196
1197 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1198 selectJump(lhs == rhs);
1199 }
1200
executeAreRangesEqual()1201 void ByteCodeExecutor::executeAreRangesEqual() {
1202 LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1203 PDLValue::Kind valueKind = read<PDLValue::Kind>();
1204 const void *lhs = read<const void *>();
1205 const void *rhs = read<const void *>();
1206
1207 switch (valueKind) {
1208 case PDLValue::Kind::TypeRange: {
1209 const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1210 const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1211 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1212 selectJump(*lhsRange == *rhsRange);
1213 break;
1214 }
1215 case PDLValue::Kind::ValueRange: {
1216 const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1217 const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1218 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1219 selectJump(*lhsRange == *rhsRange);
1220 break;
1221 }
1222 default:
1223 llvm_unreachable("unexpected `AreRangesEqual` value kind");
1224 }
1225 }
1226
executeBranch()1227 void ByteCodeExecutor::executeBranch() {
1228 LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1229 curCodeIt = &code[read<ByteCodeAddr>()];
1230 }
1231
executeCheckOperandCount()1232 void ByteCodeExecutor::executeCheckOperandCount() {
1233 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1234 Operation *op = read<Operation *>();
1235 uint32_t expectedCount = read<uint32_t>();
1236 bool compareAtLeast = read();
1237
1238 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1239 << " * Expected: " << expectedCount << "\n"
1240 << " * Comparator: "
1241 << (compareAtLeast ? ">=" : "==") << "\n");
1242 if (compareAtLeast)
1243 selectJump(op->getNumOperands() >= expectedCount);
1244 else
1245 selectJump(op->getNumOperands() == expectedCount);
1246 }
1247
executeCheckOperationName()1248 void ByteCodeExecutor::executeCheckOperationName() {
1249 LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1250 Operation *op = read<Operation *>();
1251 OperationName expectedName = read<OperationName>();
1252
1253 LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1254 << " * Expected: \"" << expectedName << "\"\n");
1255 selectJump(op->getName() == expectedName);
1256 }
1257
executeCheckResultCount()1258 void ByteCodeExecutor::executeCheckResultCount() {
1259 LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1260 Operation *op = read<Operation *>();
1261 uint32_t expectedCount = read<uint32_t>();
1262 bool compareAtLeast = read();
1263
1264 LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1265 << " * Expected: " << expectedCount << "\n"
1266 << " * Comparator: "
1267 << (compareAtLeast ? ">=" : "==") << "\n");
1268 if (compareAtLeast)
1269 selectJump(op->getNumResults() >= expectedCount);
1270 else
1271 selectJump(op->getNumResults() == expectedCount);
1272 }
1273
executeCheckTypes()1274 void ByteCodeExecutor::executeCheckTypes() {
1275 LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1276 TypeRange *lhs = read<TypeRange *>();
1277 Attribute rhs = read<Attribute>();
1278 LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1279
1280 selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1281 }
1282
executeCreateTypes()1283 void ByteCodeExecutor::executeCreateTypes() {
1284 LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1285 unsigned memIndex = read();
1286 unsigned rangeIndex = read();
1287 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1288
1289 LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1290
1291 // Allocate a buffer for this type range.
1292 llvm::OwningArrayRef<Type> storage(typesAttr.size());
1293 llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1294 allocatedTypeRangeMemory.emplace_back(std::move(storage));
1295
1296 // Assign this to the range slot and use the range as the value for the
1297 // memory index.
1298 typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1299 memory[memIndex] = &typeRangeMemory[rangeIndex];
1300 }
1301
executeCreateOperation(PatternRewriter & rewriter,Location mainRewriteLoc)1302 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1303 Location mainRewriteLoc) {
1304 LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1305
1306 unsigned memIndex = read();
1307 OperationState state(mainRewriteLoc, read<OperationName>());
1308 readValueList(state.operands);
1309 for (unsigned i = 0, e = read(); i != e; ++i) {
1310 Identifier name = read<Identifier>();
1311 if (Attribute attr = read<Attribute>())
1312 state.addAttribute(name, attr);
1313 }
1314
1315 for (unsigned i = 0, e = read(); i != e; ++i) {
1316 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1317 state.types.push_back(read<Type>());
1318 continue;
1319 }
1320
1321 // If we find a null range, this signals that the types are infered.
1322 if (TypeRange *resultTypes = read<TypeRange *>()) {
1323 state.types.append(resultTypes->begin(), resultTypes->end());
1324 continue;
1325 }
1326
1327 // Handle the case where the operation has inferred types.
1328 InferTypeOpInterface::Concept *concept =
1329 state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
1330
1331 // TODO: Handle failure.
1332 state.types.clear();
1333 if (failed(concept->inferReturnTypes(
1334 state.getContext(), state.location, state.operands,
1335 state.attributes.getDictionary(state.getContext()), state.regions,
1336 state.types)))
1337 return;
1338 break;
1339 }
1340
1341 Operation *resultOp = rewriter.createOperation(state);
1342 memory[memIndex] = resultOp;
1343
1344 LLVM_DEBUG({
1345 llvm::dbgs() << " * Attributes: "
1346 << state.attributes.getDictionary(state.getContext())
1347 << "\n * Operands: ";
1348 llvm::interleaveComma(state.operands, llvm::dbgs());
1349 llvm::dbgs() << "\n * Result Types: ";
1350 llvm::interleaveComma(state.types, llvm::dbgs());
1351 llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1352 });
1353 }
1354
executeEraseOp(PatternRewriter & rewriter)1355 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1356 LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1357 Operation *op = read<Operation *>();
1358
1359 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1360 rewriter.eraseOp(op);
1361 }
1362
executeGetAttribute()1363 void ByteCodeExecutor::executeGetAttribute() {
1364 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1365 unsigned memIndex = read();
1366 Operation *op = read<Operation *>();
1367 Identifier attrName = read<Identifier>();
1368 Attribute attr = op->getAttr(attrName);
1369
1370 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1371 << " * Attribute: " << attrName << "\n"
1372 << " * Result: " << attr << "\n");
1373 memory[memIndex] = attr.getAsOpaquePointer();
1374 }
1375
executeGetAttributeType()1376 void ByteCodeExecutor::executeGetAttributeType() {
1377 LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1378 unsigned memIndex = read();
1379 Attribute attr = read<Attribute>();
1380 Type type = attr ? attr.getType() : Type();
1381
1382 LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1383 << " * Result: " << type << "\n");
1384 memory[memIndex] = type.getAsOpaquePointer();
1385 }
1386
executeGetDefiningOp()1387 void ByteCodeExecutor::executeGetDefiningOp() {
1388 LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1389 unsigned memIndex = read();
1390 Operation *op = nullptr;
1391 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1392 Value value = read<Value>();
1393 if (value)
1394 op = value.getDefiningOp();
1395 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1396 } else {
1397 ValueRange *values = read<ValueRange *>();
1398 if (values && !values->empty()) {
1399 op = values->front().getDefiningOp();
1400 }
1401 LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1402 }
1403
1404 LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1405 memory[memIndex] = op;
1406 }
1407
executeGetOperand(unsigned index)1408 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1409 Operation *op = read<Operation *>();
1410 unsigned memIndex = read();
1411 Value operand =
1412 index < op->getNumOperands() ? op->getOperand(index) : Value();
1413
1414 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1415 << " * Index: " << index << "\n"
1416 << " * Result: " << operand << "\n");
1417 memory[memIndex] = operand.getAsOpaquePointer();
1418 }
1419
1420 /// This function is the internal implementation of `GetResults` and
1421 /// `GetOperands` that provides support for extracting a value range from the
1422 /// given operation.
1423 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1424 static void *
executeGetOperandsResults(RangeT values,Operation * op,unsigned index,ByteCodeField rangeIndex,StringRef attrSizedSegments,MutableArrayRef<ValueRange> & valueRangeMemory)1425 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1426 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1427 MutableArrayRef<ValueRange> &valueRangeMemory) {
1428 // Check for the sentinel index that signals that all values should be
1429 // returned.
1430 if (index == std::numeric_limits<uint32_t>::max()) {
1431 LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1432 // `values` is already the full value range.
1433
1434 // Otherwise, check to see if this operation uses AttrSizedSegments.
1435 } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1436 LLVM_DEBUG(llvm::dbgs()
1437 << " * Extracting values from `" << attrSizedSegments << "`\n");
1438
1439 auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1440 if (!segmentAttr || segmentAttr.getNumElements() <= index)
1441 return nullptr;
1442
1443 auto segments = segmentAttr.getValues<int32_t>();
1444 unsigned startIndex =
1445 std::accumulate(segments.begin(), segments.begin() + index, 0);
1446 values = values.slice(startIndex, *std::next(segments.begin(), index));
1447
1448 LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1449 << *std::next(segments.begin(), index) << "]\n");
1450
1451 // Otherwise, assume this is the last operand group of the operation.
1452 // FIXME: We currently don't support operations with
1453 // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1454 // have a way to detect it's presence.
1455 } else if (values.size() >= index) {
1456 LLVM_DEBUG(llvm::dbgs()
1457 << " * Treating values as trailing variadic range\n");
1458 values = values.drop_front(index);
1459
1460 // If we couldn't detect a way to compute the values, bail out.
1461 } else {
1462 return nullptr;
1463 }
1464
1465 // If the range index is valid, we are returning a range.
1466 if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1467 valueRangeMemory[rangeIndex] = values;
1468 return &valueRangeMemory[rangeIndex];
1469 }
1470
1471 // If a range index wasn't provided, the range is required to be non-variadic.
1472 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1473 }
1474
executeGetOperands()1475 void ByteCodeExecutor::executeGetOperands() {
1476 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1477 unsigned index = read<uint32_t>();
1478 Operation *op = read<Operation *>();
1479 ByteCodeField rangeIndex = read();
1480
1481 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1482 op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1483 valueRangeMemory);
1484 if (!result)
1485 LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1486 memory[read()] = result;
1487 }
1488
executeGetResult(unsigned index)1489 void ByteCodeExecutor::executeGetResult(unsigned index) {
1490 Operation *op = read<Operation *>();
1491 unsigned memIndex = read();
1492 OpResult result =
1493 index < op->getNumResults() ? op->getResult(index) : OpResult();
1494
1495 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1496 << " * Index: " << index << "\n"
1497 << " * Result: " << result << "\n");
1498 memory[memIndex] = result.getAsOpaquePointer();
1499 }
1500
executeGetResults()1501 void ByteCodeExecutor::executeGetResults() {
1502 LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1503 unsigned index = read<uint32_t>();
1504 Operation *op = read<Operation *>();
1505 ByteCodeField rangeIndex = read();
1506
1507 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1508 op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1509 valueRangeMemory);
1510 if (!result)
1511 LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1512 memory[read()] = result;
1513 }
1514
executeGetValueType()1515 void ByteCodeExecutor::executeGetValueType() {
1516 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1517 unsigned memIndex = read();
1518 Value value = read<Value>();
1519 Type type = value ? value.getType() : Type();
1520
1521 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1522 << " * Result: " << type << "\n");
1523 memory[memIndex] = type.getAsOpaquePointer();
1524 }
1525
executeGetValueRangeTypes()1526 void ByteCodeExecutor::executeGetValueRangeTypes() {
1527 LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1528 unsigned memIndex = read();
1529 unsigned rangeIndex = read();
1530 ValueRange *values = read<ValueRange *>();
1531 if (!values) {
1532 LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
1533 memory[memIndex] = nullptr;
1534 return;
1535 }
1536
1537 LLVM_DEBUG({
1538 llvm::dbgs() << " * Values (" << values->size() << "): ";
1539 llvm::interleaveComma(*values, llvm::dbgs());
1540 llvm::dbgs() << "\n * Result: ";
1541 llvm::interleaveComma(values->getType(), llvm::dbgs());
1542 llvm::dbgs() << "\n";
1543 });
1544 typeRangeMemory[rangeIndex] = values->getType();
1545 memory[memIndex] = &typeRangeMemory[rangeIndex];
1546 }
1547
executeIsNotNull()1548 void ByteCodeExecutor::executeIsNotNull() {
1549 LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1550 const void *value = read<const void *>();
1551
1552 LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1553 selectJump(value != nullptr);
1554 }
1555
executeRecordMatch(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> & matches)1556 void ByteCodeExecutor::executeRecordMatch(
1557 PatternRewriter &rewriter,
1558 SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1559 LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1560 unsigned patternIndex = read();
1561 PatternBenefit benefit = currentPatternBenefits[patternIndex];
1562 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1563
1564 // If the benefit of the pattern is impossible, skip the processing of the
1565 // rest of the pattern.
1566 if (benefit.isImpossibleToMatch()) {
1567 LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
1568 curCodeIt = dest;
1569 return;
1570 }
1571
1572 // Create a fused location containing the locations of each of the
1573 // operations used in the match. This will be used as the location for
1574 // created operations during the rewrite that don't already have an
1575 // explicit location set.
1576 unsigned numMatchLocs = read();
1577 SmallVector<Location, 4> matchLocs;
1578 matchLocs.reserve(numMatchLocs);
1579 for (unsigned i = 0; i != numMatchLocs; ++i)
1580 matchLocs.push_back(read<Operation *>()->getLoc());
1581 Location matchLoc = rewriter.getFusedLoc(matchLocs);
1582
1583 LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1584 << " * Location: " << matchLoc << "\n");
1585 matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1586 PDLByteCode::MatchResult &match = matches.back();
1587
1588 // Record all of the inputs to the match. If any of the inputs are ranges, we
1589 // will also need to remap the range pointer to memory stored in the match
1590 // state.
1591 unsigned numInputs = read();
1592 match.values.reserve(numInputs);
1593 match.typeRangeValues.reserve(numInputs);
1594 match.valueRangeValues.reserve(numInputs);
1595 for (unsigned i = 0; i < numInputs; ++i) {
1596 switch (read<PDLValue::Kind>()) {
1597 case PDLValue::Kind::TypeRange:
1598 match.typeRangeValues.push_back(*read<TypeRange *>());
1599 match.values.push_back(&match.typeRangeValues.back());
1600 break;
1601 case PDLValue::Kind::ValueRange:
1602 match.valueRangeValues.push_back(*read<ValueRange *>());
1603 match.values.push_back(&match.valueRangeValues.back());
1604 break;
1605 default:
1606 match.values.push_back(read<const void *>());
1607 break;
1608 }
1609 }
1610 curCodeIt = dest;
1611 }
1612
executeReplaceOp(PatternRewriter & rewriter)1613 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1614 LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1615 Operation *op = read<Operation *>();
1616 SmallVector<Value, 16> args;
1617 readValueList(args);
1618
1619 LLVM_DEBUG({
1620 llvm::dbgs() << " * Operation: " << *op << "\n"
1621 << " * Values: ";
1622 llvm::interleaveComma(args, llvm::dbgs());
1623 llvm::dbgs() << "\n";
1624 });
1625 rewriter.replaceOp(op, args);
1626 }
1627
executeSwitchAttribute()1628 void ByteCodeExecutor::executeSwitchAttribute() {
1629 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1630 Attribute value = read<Attribute>();
1631 ArrayAttr cases = read<ArrayAttr>();
1632 handleSwitch(value, cases);
1633 }
1634
executeSwitchOperandCount()1635 void ByteCodeExecutor::executeSwitchOperandCount() {
1636 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1637 Operation *op = read<Operation *>();
1638 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1639
1640 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1641 handleSwitch(op->getNumOperands(), cases);
1642 }
1643
executeSwitchOperationName()1644 void ByteCodeExecutor::executeSwitchOperationName() {
1645 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1646 OperationName value = read<Operation *>()->getName();
1647 size_t caseCount = read();
1648
1649 // The operation names are stored in-line, so to print them out for
1650 // debugging purposes we need to read the array before executing the
1651 // switch so that we can display all of the possible values.
1652 LLVM_DEBUG({
1653 const ByteCodeField *prevCodeIt = curCodeIt;
1654 llvm::dbgs() << " * Value: " << value << "\n"
1655 << " * Cases: ";
1656 llvm::interleaveComma(
1657 llvm::map_range(llvm::seq<size_t>(0, caseCount),
1658 [&](size_t) { return read<OperationName>(); }),
1659 llvm::dbgs());
1660 llvm::dbgs() << "\n";
1661 curCodeIt = prevCodeIt;
1662 });
1663
1664 // Try to find the switch value within any of the cases.
1665 for (size_t i = 0; i != caseCount; ++i) {
1666 if (read<OperationName>() == value) {
1667 curCodeIt += (caseCount - i - 1);
1668 return selectJump(i + 1);
1669 }
1670 }
1671 selectJump(size_t(0));
1672 }
1673
executeSwitchResultCount()1674 void ByteCodeExecutor::executeSwitchResultCount() {
1675 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1676 Operation *op = read<Operation *>();
1677 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1678
1679 LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1680 handleSwitch(op->getNumResults(), cases);
1681 }
1682
executeSwitchType()1683 void ByteCodeExecutor::executeSwitchType() {
1684 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1685 Type value = read<Type>();
1686 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1687 handleSwitch(value, cases);
1688 }
1689
executeSwitchTypes()1690 void ByteCodeExecutor::executeSwitchTypes() {
1691 LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
1692 TypeRange *value = read<TypeRange *>();
1693 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
1694 if (!value) {
1695 LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
1696 return selectJump(size_t(0));
1697 }
1698 handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
1699 return value == caseValue.getAsValueRange<TypeAttr>();
1700 });
1701 }
1702
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)1703 void ByteCodeExecutor::execute(
1704 PatternRewriter &rewriter,
1705 SmallVectorImpl<PDLByteCode::MatchResult> *matches,
1706 Optional<Location> mainRewriteLoc) {
1707 while (true) {
1708 OpCode opCode = static_cast<OpCode>(read());
1709 switch (opCode) {
1710 case ApplyConstraint:
1711 executeApplyConstraint(rewriter);
1712 break;
1713 case ApplyRewrite:
1714 executeApplyRewrite(rewriter);
1715 break;
1716 case AreEqual:
1717 executeAreEqual();
1718 break;
1719 case AreRangesEqual:
1720 executeAreRangesEqual();
1721 break;
1722 case Branch:
1723 executeBranch();
1724 break;
1725 case CheckOperandCount:
1726 executeCheckOperandCount();
1727 break;
1728 case CheckOperationName:
1729 executeCheckOperationName();
1730 break;
1731 case CheckResultCount:
1732 executeCheckResultCount();
1733 break;
1734 case CheckTypes:
1735 executeCheckTypes();
1736 break;
1737 case CreateOperation:
1738 executeCreateOperation(rewriter, *mainRewriteLoc);
1739 break;
1740 case CreateTypes:
1741 executeCreateTypes();
1742 break;
1743 case EraseOp:
1744 executeEraseOp(rewriter);
1745 break;
1746 case Finalize:
1747 LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1748 return;
1749 case GetAttribute:
1750 executeGetAttribute();
1751 break;
1752 case GetAttributeType:
1753 executeGetAttributeType();
1754 break;
1755 case GetDefiningOp:
1756 executeGetDefiningOp();
1757 break;
1758 case GetOperand0:
1759 case GetOperand1:
1760 case GetOperand2:
1761 case GetOperand3: {
1762 unsigned index = opCode - GetOperand0;
1763 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
1764 executeGetOperand(index);
1765 break;
1766 }
1767 case GetOperandN:
1768 LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
1769 executeGetOperand(read<uint32_t>());
1770 break;
1771 case GetOperands:
1772 executeGetOperands();
1773 break;
1774 case GetResult0:
1775 case GetResult1:
1776 case GetResult2:
1777 case GetResult3: {
1778 unsigned index = opCode - GetResult0;
1779 LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
1780 executeGetResult(index);
1781 break;
1782 }
1783 case GetResultN:
1784 LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
1785 executeGetResult(read<uint32_t>());
1786 break;
1787 case GetResults:
1788 executeGetResults();
1789 break;
1790 case GetValueType:
1791 executeGetValueType();
1792 break;
1793 case GetValueRangeTypes:
1794 executeGetValueRangeTypes();
1795 break;
1796 case IsNotNull:
1797 executeIsNotNull();
1798 break;
1799 case RecordMatch:
1800 assert(matches &&
1801 "expected matches to be provided when executing the matcher");
1802 executeRecordMatch(rewriter, *matches);
1803 break;
1804 case ReplaceOp:
1805 executeReplaceOp(rewriter);
1806 break;
1807 case SwitchAttribute:
1808 executeSwitchAttribute();
1809 break;
1810 case SwitchOperandCount:
1811 executeSwitchOperandCount();
1812 break;
1813 case SwitchOperationName:
1814 executeSwitchOperationName();
1815 break;
1816 case SwitchResultCount:
1817 executeSwitchResultCount();
1818 break;
1819 case SwitchType:
1820 executeSwitchType();
1821 break;
1822 case SwitchTypes:
1823 executeSwitchTypes();
1824 break;
1825 }
1826 LLVM_DEBUG(llvm::dbgs() << "\n");
1827 }
1828 }
1829
1830 /// Run the pattern matcher on the given root operation, collecting the matched
1831 /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const1832 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1833 SmallVectorImpl<MatchResult> &matches,
1834 PDLByteCodeMutableState &state) const {
1835 // The first memory slot is always the root operation.
1836 state.memory[0] = op;
1837
1838 // The matcher function always starts at code address 0.
1839 ByteCodeExecutor executor(
1840 matcherByteCode.data(), state.memory, state.typeRangeMemory,
1841 state.allocatedTypeRangeMemory, state.valueRangeMemory,
1842 state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
1843 state.currentPatternBenefits, patterns, constraintFunctions,
1844 rewriteFunctions);
1845 executor.execute(rewriter, &matches);
1846
1847 // Order the found matches by benefit.
1848 std::stable_sort(matches.begin(), matches.end(),
1849 [](const MatchResult &lhs, const MatchResult &rhs) {
1850 return lhs.benefit > rhs.benefit;
1851 });
1852 }
1853
1854 /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const1855 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1856 PDLByteCodeMutableState &state) const {
1857 // The arguments of the rewrite function are stored at the start of the
1858 // memory buffer.
1859 llvm::copy(match.values, state.memory.begin());
1860
1861 ByteCodeExecutor executor(
1862 &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1863 state.typeRangeMemory, state.allocatedTypeRangeMemory,
1864 state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
1865 rewriterByteCode, state.currentPatternBenefits, patterns,
1866 constraintFunctions, rewriteFunctions);
1867 executor.execute(rewriter, /*matches=*/nullptr, match.location);
1868 }
1869