1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37                                               ByteCodeAddr rewriterAddr) {
38   SmallVector<StringRef, 8> generatedOps;
39   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40     generatedOps =
41         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43   PatternBenefit benefit = matchOp.benefit();
44   MLIRContext *ctx = matchOp.getContext();
45 
46   // Check to see if this is pattern matches a specific operation type.
47   if (Optional<StringRef> rootKind = matchOp.rootKind())
48     return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49                               generatedOps);
50   return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51                             generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)61 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
62                                                    PatternBenefit benefit) {
63   currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
cleanupAfterMatchAndRewrite()69 void PDLByteCodeMutableState::cleanupAfterMatchAndRewrite() {
70   allocatedTypeRangeMemory.clear();
71   allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
79 enum OpCode : ByteCodeField {
80   /// Apply an externally registered constraint.
81   ApplyConstraint,
82   /// Apply an externally registered rewrite.
83   ApplyRewrite,
84   /// Check if two generic values are equal.
85   AreEqual,
86   /// Check if two ranges are equal.
87   AreRangesEqual,
88   /// Unconditional branch.
89   Branch,
90   /// Compare the operand count of an operation with a constant.
91   CheckOperandCount,
92   /// Compare the name of an operation with a constant.
93   CheckOperationName,
94   /// Compare the result count of an operation with a constant.
95   CheckResultCount,
96   /// Compare a range of types to a constant range of types.
97   CheckTypes,
98   /// Create an operation.
99   CreateOperation,
100   /// Create a range of types.
101   CreateTypes,
102   /// Erase an operation.
103   EraseOp,
104   /// Terminate a matcher or rewrite sequence.
105   Finalize,
106   /// Get a specific attribute of an operation.
107   GetAttribute,
108   /// Get the type of an attribute.
109   GetAttributeType,
110   /// Get the defining operation of a value.
111   GetDefiningOp,
112   /// Get a specific operand of an operation.
113   GetOperand0,
114   GetOperand1,
115   GetOperand2,
116   GetOperand3,
117   GetOperandN,
118   /// Get a specific operand group of an operation.
119   GetOperands,
120   /// Get a specific result of an operation.
121   GetResult0,
122   GetResult1,
123   GetResult2,
124   GetResult3,
125   GetResultN,
126   /// Get a specific result group of an operation.
127   GetResults,
128   /// Get the type of a value.
129   GetValueType,
130   /// Get the types of a value range.
131   GetValueRangeTypes,
132   /// Check if a generic value is not null.
133   IsNotNull,
134   /// Record a successful pattern match.
135   RecordMatch,
136   /// Replace an operation.
137   ReplaceOp,
138   /// Compare an attribute with a set of constants.
139   SwitchAttribute,
140   /// Compare the operand count of an operation with a set of constants.
141   SwitchOperandCount,
142   /// Compare the name of an operation with a set of constants.
143   SwitchOperationName,
144   /// Compare the result count of an operation with a set of constants.
145   SwitchResultCount,
146   /// Compare a type with a set of constants.
147   SwitchType,
148   /// Compare a range of types with a set of constants.
149   SwitchTypes,
150 };
151 } // end anonymous namespace
152 
153 //===----------------------------------------------------------------------===//
154 // ByteCode Generation
155 //===----------------------------------------------------------------------===//
156 
157 //===----------------------------------------------------------------------===//
158 // Generator
159 
160 namespace {
161 struct ByteCodeWriter;
162 
163 /// This class represents the main generator for the pattern bytecode.
164 class Generator {
165 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,ByteCodeField & maxTypeRangeMemoryIndex,ByteCodeField & maxValueRangeMemoryIndex,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)166   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
167             SmallVectorImpl<ByteCodeField> &matcherByteCode,
168             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
169             SmallVectorImpl<PDLByteCodePattern> &patterns,
170             ByteCodeField &maxValueMemoryIndex,
171             ByteCodeField &maxTypeRangeMemoryIndex,
172             ByteCodeField &maxValueRangeMemoryIndex,
173             llvm::StringMap<PDLConstraintFunction> &constraintFns,
174             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
175       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
176         rewriterByteCode(rewriterByteCode), patterns(patterns),
177         maxValueMemoryIndex(maxValueMemoryIndex),
178         maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
179         maxValueRangeMemoryIndex(maxValueRangeMemoryIndex) {
180     for (auto it : llvm::enumerate(constraintFns))
181       constraintToMemIndex.try_emplace(it.value().first(), it.index());
182     for (auto it : llvm::enumerate(rewriteFns))
183       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
184   }
185 
186   /// Generate the bytecode for the given PDL interpreter module.
187   void generate(ModuleOp module);
188 
189   /// Return the memory index to use for the given value.
getMemIndex(Value value)190   ByteCodeField &getMemIndex(Value value) {
191     assert(valueToMemIndex.count(value) &&
192            "expected memory index to be assigned");
193     return valueToMemIndex[value];
194   }
195 
196   /// Return the range memory index used to store the given range value.
getRangeStorageIndex(Value value)197   ByteCodeField &getRangeStorageIndex(Value value) {
198     assert(valueToRangeIndex.count(value) &&
199            "expected range index to be assigned");
200     return valueToRangeIndex[value];
201   }
202 
203   /// Return an index to use when referring to the given data that is uniqued in
204   /// the MLIR context.
205   template <typename T>
206   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)207   getMemIndex(T val) {
208     const void *opaqueVal = val.getAsOpaquePointer();
209 
210     // Get or insert a reference to this value.
211     auto it = uniquedDataToMemIndex.try_emplace(
212         opaqueVal, maxValueMemoryIndex + uniquedData.size());
213     if (it.second)
214       uniquedData.push_back(opaqueVal);
215     return it.first->second;
216   }
217 
218 private:
219   /// Allocate memory indices for the results of operations within the matcher
220   /// and rewriters.
221   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
222 
223   /// Generate the bytecode for the given operation.
224   void generate(Operation *op, ByteCodeWriter &writer);
225   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
226   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
227   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
228   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
229   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
230   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
231   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
232   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
233   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
234   void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
235   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
236   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
237   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
238   void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
239   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
240   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
241   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
242   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
243   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
244   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
245   void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
246   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
247   void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
248   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
249   void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
250   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
251   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
252   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
253   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
254   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
255   void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
256   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
257   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
258   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
259 
260   /// Mapping from value to its corresponding memory index.
261   DenseMap<Value, ByteCodeField> valueToMemIndex;
262 
263   /// Mapping from a range value to its corresponding range storage index.
264   DenseMap<Value, ByteCodeField> valueToRangeIndex;
265 
266   /// Mapping from the name of an externally registered rewrite to its index in
267   /// the bytecode registry.
268   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
269 
270   /// Mapping from the name of an externally registered constraint to its index
271   /// in the bytecode registry.
272   llvm::StringMap<ByteCodeField> constraintToMemIndex;
273 
274   /// Mapping from rewriter function name to the bytecode address of the
275   /// rewriter function in byte.
276   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
277 
278   /// Mapping from a uniqued storage object to its memory index within
279   /// `uniquedData`.
280   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
281 
282   /// The current MLIR context.
283   MLIRContext *ctx;
284 
285   /// Data of the ByteCode class to be populated.
286   std::vector<const void *> &uniquedData;
287   SmallVectorImpl<ByteCodeField> &matcherByteCode;
288   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
289   SmallVectorImpl<PDLByteCodePattern> &patterns;
290   ByteCodeField &maxValueMemoryIndex;
291   ByteCodeField &maxTypeRangeMemoryIndex;
292   ByteCodeField &maxValueRangeMemoryIndex;
293 };
294 
295 /// This class provides utilities for writing a bytecode stream.
296 struct ByteCodeWriter {
ByteCodeWriter__anon6093d5b70211::ByteCodeWriter297   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
298       : bytecode(bytecode), generator(generator) {}
299 
300   /// Append a field to the bytecode.
append__anon6093d5b70211::ByteCodeWriter301   void append(ByteCodeField field) { bytecode.push_back(field); }
append__anon6093d5b70211::ByteCodeWriter302   void append(OpCode opCode) { bytecode.push_back(opCode); }
303 
304   /// Append an address to the bytecode.
append__anon6093d5b70211::ByteCodeWriter305   void append(ByteCodeAddr field) {
306     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
307                   "unexpected ByteCode address size");
308 
309     ByteCodeField fieldParts[2];
310     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
311     bytecode.append({fieldParts[0], fieldParts[1]});
312   }
313 
314   /// Append a successor range to the bytecode, the exact address will need to
315   /// be resolved later.
append__anon6093d5b70211::ByteCodeWriter316   void append(SuccessorRange successors) {
317     // Add back references to the any successors so that the address can be
318     // resolved later.
319     for (Block *successor : successors) {
320       unresolvedSuccessorRefs[successor].push_back(bytecode.size());
321       append(ByteCodeAddr(0));
322     }
323   }
324 
325   /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anon6093d5b70211::ByteCodeWriter326   void appendPDLValueList(OperandRange values) {
327     bytecode.push_back(values.size());
328     for (Value value : values)
329       appendPDLValue(value);
330   }
331 
332   /// Append a value as a PDLValue.
appendPDLValue__anon6093d5b70211::ByteCodeWriter333   void appendPDLValue(Value value) {
334     appendPDLValueKind(value);
335     append(value);
336   }
337 
338   /// Append the PDLValue::Kind of the given value.
appendPDLValueKind__anon6093d5b70211::ByteCodeWriter339   void appendPDLValueKind(Value value) {
340     // Append the type of the value in addition to the value itself.
341     PDLValue::Kind kind =
342         TypeSwitch<Type, PDLValue::Kind>(value.getType())
343             .Case<pdl::AttributeType>(
344                 [](Type) { return PDLValue::Kind::Attribute; })
345             .Case<pdl::OperationType>(
346                 [](Type) { return PDLValue::Kind::Operation; })
347             .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
348               if (rangeTy.getElementType().isa<pdl::TypeType>())
349                 return PDLValue::Kind::TypeRange;
350               return PDLValue::Kind::ValueRange;
351             })
352             .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
353             .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
354     bytecode.push_back(static_cast<ByteCodeField>(kind));
355   }
356 
357   /// Check if the given class `T` has an iterator type.
358   template <typename T, typename... Args>
359   using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
360 
361   /// Append a value that will be stored in a memory slot and not inline within
362   /// the bytecode.
363   template <typename T>
364   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
365                    std::is_pointer<T>::value>
append__anon6093d5b70211::ByteCodeWriter366   append(T value) {
367     bytecode.push_back(generator.getMemIndex(value));
368   }
369 
370   /// Append a range of values.
371   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
372   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anon6093d5b70211::ByteCodeWriter373   append(T range) {
374     bytecode.push_back(llvm::size(range));
375     for (auto it : range)
376       append(it);
377   }
378 
379   /// Append a variadic number of fields to the bytecode.
380   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anon6093d5b70211::ByteCodeWriter381   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
382     append(field);
383     append(field2, fields...);
384   }
385 
386   /// Successor references in the bytecode that have yet to be resolved.
387   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
388 
389   /// The underlying bytecode buffer.
390   SmallVectorImpl<ByteCodeField> &bytecode;
391 
392   /// The main generator producing PDL.
393   Generator &generator;
394 };
395 
396 /// This class represents a live range of PDL Interpreter values, containing
397 /// information about when values are live within a match/rewrite.
398 struct ByteCodeLiveRange {
399   using Set = llvm::IntervalMap<ByteCodeField, char, 16>;
400   using Allocator = Set::Allocator;
401 
ByteCodeLiveRange__anon6093d5b70211::ByteCodeLiveRange402   ByteCodeLiveRange(Allocator &alloc) : liveness(alloc) {}
403 
404   /// Union this live range with the one provided.
unionWith__anon6093d5b70211::ByteCodeLiveRange405   void unionWith(const ByteCodeLiveRange &rhs) {
406     for (auto it = rhs.liveness.begin(), e = rhs.liveness.end(); it != e; ++it)
407       liveness.insert(it.start(), it.stop(), /*dummyValue*/ 0);
408   }
409 
410   /// Returns true if this range overlaps with the one provided.
overlaps__anon6093d5b70211::ByteCodeLiveRange411   bool overlaps(const ByteCodeLiveRange &rhs) const {
412     return llvm::IntervalMapOverlaps<Set, Set>(liveness, rhs.liveness).valid();
413   }
414 
415   /// A map representing the ranges of the match/rewrite that a value is live in
416   /// the interpreter.
417   llvm::IntervalMap<ByteCodeField, char, 16> liveness;
418 
419   /// The type range storage index for this range.
420   Optional<unsigned> typeRangeIndex;
421 
422   /// The value range storage index for this range.
423   Optional<unsigned> valueRangeIndex;
424 };
425 } // end anonymous namespace
426 
generate(ModuleOp module)427 void Generator::generate(ModuleOp module) {
428   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
429       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
430   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
431       pdl_interp::PDLInterpDialect::getRewriterModuleName());
432   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
433 
434   // Allocate memory indices for the results of operations within the matcher
435   // and rewriters.
436   allocateMemoryIndices(matcherFunc, rewriterModule);
437 
438   // Generate code for the rewriter functions.
439   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
440   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
441     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
442     for (Operation &op : rewriterFunc.getOps())
443       generate(&op, rewriterByteCodeWriter);
444   }
445   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
446          "unexpected branches in rewriter function");
447 
448   // Generate code for the matcher function.
449   DenseMap<Block *, ByteCodeAddr> blockToAddr;
450   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
451   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
452   for (Block *block : rpot) {
453     // Keep track of where this block begins within the matcher function.
454     blockToAddr.try_emplace(block, matcherByteCode.size());
455     for (Operation &op : *block)
456       generate(&op, matcherByteCodeWriter);
457   }
458 
459   // Resolve successor references in the matcher.
460   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
461     ByteCodeAddr addr = blockToAddr[it.first];
462     for (unsigned offsetToFix : it.second)
463       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
464   }
465 }
466 
allocateMemoryIndices(FuncOp matcherFunc,ModuleOp rewriterModule)467 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
468                                       ModuleOp rewriterModule) {
469   // Rewriters use simplistic allocation scheme that simply assigns an index to
470   // each result.
471   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
472     ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
473     auto processRewriterValue = [&](Value val) {
474       valueToMemIndex.try_emplace(val, index++);
475       if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
476         Type elementTy = rangeType.getElementType();
477         if (elementTy.isa<pdl::TypeType>())
478           valueToRangeIndex.try_emplace(val, typeRangeIndex++);
479         else if (elementTy.isa<pdl::ValueType>())
480           valueToRangeIndex.try_emplace(val, valueRangeIndex++);
481       }
482     };
483 
484     for (BlockArgument arg : rewriterFunc.getArguments())
485       processRewriterValue(arg);
486     rewriterFunc.getBody().walk([&](Operation *op) {
487       for (Value result : op->getResults())
488         processRewriterValue(result);
489     });
490     if (index > maxValueMemoryIndex)
491       maxValueMemoryIndex = index;
492     if (typeRangeIndex > maxTypeRangeMemoryIndex)
493       maxTypeRangeMemoryIndex = typeRangeIndex;
494     if (valueRangeIndex > maxValueRangeMemoryIndex)
495       maxValueRangeMemoryIndex = valueRangeIndex;
496   }
497 
498   // The matcher function uses a more sophisticated numbering that tries to
499   // minimize the number of memory indices assigned. This is done by determining
500   // a live range of the values within the matcher, then the allocation is just
501   // finding the minimal number of overlapping live ranges. This is essentially
502   // a simplified form of register allocation where we don't necessarily have a
503   // limited number of registers, but we still want to minimize the number used.
504   DenseMap<Operation *, ByteCodeField> opToIndex;
505   matcherFunc.getBody().walk([&](Operation *op) {
506     opToIndex.insert(std::make_pair(op, opToIndex.size()));
507   });
508 
509   // Liveness info for each of the defs within the matcher.
510   ByteCodeLiveRange::Allocator allocator;
511   DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
512 
513   // Assign the root operation being matched to slot 0.
514   BlockArgument rootOpArg = matcherFunc.getArgument(0);
515   valueToMemIndex[rootOpArg] = 0;
516 
517   // Walk each of the blocks, computing the def interval that the value is used.
518   Liveness matcherLiveness(matcherFunc);
519   for (Block &block : matcherFunc.getBody()) {
520     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
521     assert(info && "expected liveness info for block");
522     auto processValue = [&](Value value, Operation *firstUseOrDef) {
523       // We don't need to process the root op argument, this value is always
524       // assigned to the first memory slot.
525       if (value == rootOpArg)
526         return;
527 
528       // Set indices for the range of this block that the value is used.
529       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
530       defRangeIt->second.liveness.insert(
531           opToIndex[firstUseOrDef],
532           opToIndex[info->getEndOperation(value, firstUseOrDef)],
533           /*dummyValue*/ 0);
534 
535       // Check to see if this value is a range type.
536       if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
537         Type eleType = rangeTy.getElementType();
538         if (eleType.isa<pdl::TypeType>())
539           defRangeIt->second.typeRangeIndex = 0;
540         else if (eleType.isa<pdl::ValueType>())
541           defRangeIt->second.valueRangeIndex = 0;
542       }
543     };
544 
545     // Process the live-ins of this block.
546     for (Value liveIn : info->in())
547       processValue(liveIn, &block.front());
548 
549     // Process any new defs within this block.
550     for (Operation &op : block)
551       for (Value result : op.getResults())
552         processValue(result, &op);
553   }
554 
555   // Greedily allocate memory slots using the computed def live ranges.
556   std::vector<ByteCodeLiveRange> allocatedIndices;
557   ByteCodeField numIndices = 1, numTypeRanges = 0, numValueRanges = 0;
558   for (auto &defIt : valueDefRanges) {
559     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
560     ByteCodeLiveRange &defRange = defIt.second;
561 
562     // Try to allocate to an existing index.
563     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
564       ByteCodeLiveRange &existingRange = existingIndexIt.value();
565       if (!defRange.overlaps(existingRange)) {
566         existingRange.unionWith(defRange);
567         memIndex = existingIndexIt.index() + 1;
568 
569         if (defRange.typeRangeIndex) {
570           if (!existingRange.typeRangeIndex)
571             existingRange.typeRangeIndex = numTypeRanges++;
572           valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
573         } else if (defRange.valueRangeIndex) {
574           if (!existingRange.valueRangeIndex)
575             existingRange.valueRangeIndex = numValueRanges++;
576           valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
577         }
578         break;
579       }
580     }
581 
582     // If no existing index could be used, add a new one.
583     if (memIndex == 0) {
584       allocatedIndices.emplace_back(allocator);
585       ByteCodeLiveRange &newRange = allocatedIndices.back();
586       newRange.unionWith(defRange);
587 
588       // Allocate an index for type/value ranges.
589       if (defRange.typeRangeIndex) {
590         newRange.typeRangeIndex = numTypeRanges;
591         valueToRangeIndex[defIt.first] = numTypeRanges++;
592       } else if (defRange.valueRangeIndex) {
593         newRange.valueRangeIndex = numValueRanges;
594         valueToRangeIndex[defIt.first] = numValueRanges++;
595       }
596 
597       memIndex = allocatedIndices.size();
598       ++numIndices;
599     }
600   }
601 
602   // Update the max number of indices.
603   if (numIndices > maxValueMemoryIndex)
604     maxValueMemoryIndex = numIndices;
605   if (numTypeRanges > maxTypeRangeMemoryIndex)
606     maxTypeRangeMemoryIndex = numTypeRanges;
607   if (numValueRanges > maxValueRangeMemoryIndex)
608     maxValueRangeMemoryIndex = numValueRanges;
609 }
610 
generate(Operation * op,ByteCodeWriter & writer)611 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
612   TypeSwitch<Operation *>(op)
613       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
614             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
615             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
616             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
617             pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
618             pdl_interp::CreateAttributeOp, pdl_interp::CreateOperationOp,
619             pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
620             pdl_interp::EraseOp, pdl_interp::FinalizeOp,
621             pdl_interp::GetAttributeOp, pdl_interp::GetAttributeTypeOp,
622             pdl_interp::GetDefiningOpOp, pdl_interp::GetOperandOp,
623             pdl_interp::GetOperandsOp, pdl_interp::GetResultOp,
624             pdl_interp::GetResultsOp, pdl_interp::GetValueTypeOp,
625             pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
626             pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
627             pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
628             pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
629             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
630           [&](auto interpOp) { this->generate(interpOp, writer); })
631       .Default([](Operation *) {
632         llvm_unreachable("unknown `pdl_interp` operation");
633       });
634 }
635 
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)636 void Generator::generate(pdl_interp::ApplyConstraintOp op,
637                          ByteCodeWriter &writer) {
638   assert(constraintToMemIndex.count(op.name()) &&
639          "expected index for constraint function");
640   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
641                 op.constParamsAttr());
642   writer.appendPDLValueList(op.args());
643   writer.append(op.getSuccessors());
644 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)645 void Generator::generate(pdl_interp::ApplyRewriteOp op,
646                          ByteCodeWriter &writer) {
647   assert(externalRewriterToMemIndex.count(op.name()) &&
648          "expected index for rewrite function");
649   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
650                 op.constParamsAttr());
651   writer.appendPDLValueList(op.args());
652 
653   ResultRange results = op.results();
654   writer.append(ByteCodeField(results.size()));
655   for (Value result : results) {
656     // In debug mode we also record the expected kind of the result, so that we
657     // can provide extra verification of the native rewrite function.
658 #ifndef NDEBUG
659     writer.appendPDLValueKind(result);
660 #endif
661 
662     // Range results also need to append the range storage index.
663     if (result.getType().isa<pdl::RangeType>())
664       writer.append(getRangeStorageIndex(result));
665     writer.append(result);
666   }
667 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)668 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
669   Value lhs = op.lhs();
670   if (lhs.getType().isa<pdl::RangeType>()) {
671     writer.append(OpCode::AreRangesEqual);
672     writer.appendPDLValueKind(lhs);
673     writer.append(op.lhs(), op.rhs(), op.getSuccessors());
674     return;
675   }
676 
677   writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
678 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)679 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
680   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
681 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)682 void Generator::generate(pdl_interp::CheckAttributeOp op,
683                          ByteCodeWriter &writer) {
684   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
685                 op.getSuccessors());
686 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)687 void Generator::generate(pdl_interp::CheckOperandCountOp op,
688                          ByteCodeWriter &writer) {
689   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
690                 static_cast<ByteCodeField>(op.compareAtLeast()),
691                 op.getSuccessors());
692 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)693 void Generator::generate(pdl_interp::CheckOperationNameOp op,
694                          ByteCodeWriter &writer) {
695   writer.append(OpCode::CheckOperationName, op.operation(),
696                 OperationName(op.name(), ctx), op.getSuccessors());
697 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)698 void Generator::generate(pdl_interp::CheckResultCountOp op,
699                          ByteCodeWriter &writer) {
700   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
701                 static_cast<ByteCodeField>(op.compareAtLeast()),
702                 op.getSuccessors());
703 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)704 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
705   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
706 }
generate(pdl_interp::CheckTypesOp op,ByteCodeWriter & writer)707 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
708   writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
709 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)710 void Generator::generate(pdl_interp::CreateAttributeOp op,
711                          ByteCodeWriter &writer) {
712   // Simply repoint the memory index of the result to the constant.
713   getMemIndex(op.attribute()) = getMemIndex(op.value());
714 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)715 void Generator::generate(pdl_interp::CreateOperationOp op,
716                          ByteCodeWriter &writer) {
717   writer.append(OpCode::CreateOperation, op.operation(),
718                 OperationName(op.name(), ctx));
719   writer.appendPDLValueList(op.operands());
720 
721   // Add the attributes.
722   OperandRange attributes = op.attributes();
723   writer.append(static_cast<ByteCodeField>(attributes.size()));
724   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
725     writer.append(
726         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
727         std::get<1>(it));
728   }
729   writer.appendPDLValueList(op.types());
730 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)731 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
732   // Simply repoint the memory index of the result to the constant.
733   getMemIndex(op.result()) = getMemIndex(op.value());
734 }
generate(pdl_interp::CreateTypesOp op,ByteCodeWriter & writer)735 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
736   writer.append(OpCode::CreateTypes, op.result(),
737                 getRangeStorageIndex(op.result()), op.value());
738 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)739 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
740   writer.append(OpCode::EraseOp, op.operation());
741 }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)742 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
743   writer.append(OpCode::Finalize);
744 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)745 void Generator::generate(pdl_interp::GetAttributeOp op,
746                          ByteCodeWriter &writer) {
747   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
748                 Identifier::get(op.name(), ctx));
749 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)750 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
751                          ByteCodeWriter &writer) {
752   writer.append(OpCode::GetAttributeType, op.result(), op.value());
753 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)754 void Generator::generate(pdl_interp::GetDefiningOpOp op,
755                          ByteCodeWriter &writer) {
756   writer.append(OpCode::GetDefiningOp, op.operation());
757   writer.appendPDLValue(op.value());
758 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)759 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
760   uint32_t index = op.index();
761   if (index < 4)
762     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
763   else
764     writer.append(OpCode::GetOperandN, index);
765   writer.append(op.operation(), op.value());
766 }
generate(pdl_interp::GetOperandsOp op,ByteCodeWriter & writer)767 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
768   Value result = op.value();
769   Optional<uint32_t> index = op.index();
770   writer.append(OpCode::GetOperands,
771                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
772                 op.operation());
773   if (result.getType().isa<pdl::RangeType>())
774     writer.append(getRangeStorageIndex(result));
775   else
776     writer.append(std::numeric_limits<ByteCodeField>::max());
777   writer.append(result);
778 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)779 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
780   uint32_t index = op.index();
781   if (index < 4)
782     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
783   else
784     writer.append(OpCode::GetResultN, index);
785   writer.append(op.operation(), op.value());
786 }
generate(pdl_interp::GetResultsOp op,ByteCodeWriter & writer)787 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
788   Value result = op.value();
789   Optional<uint32_t> index = op.index();
790   writer.append(OpCode::GetResults,
791                 index.getValueOr(std::numeric_limits<uint32_t>::max()),
792                 op.operation());
793   if (result.getType().isa<pdl::RangeType>())
794     writer.append(getRangeStorageIndex(result));
795   else
796     writer.append(std::numeric_limits<ByteCodeField>::max());
797   writer.append(result);
798 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)799 void Generator::generate(pdl_interp::GetValueTypeOp op,
800                          ByteCodeWriter &writer) {
801   if (op.getType().isa<pdl::RangeType>()) {
802     Value result = op.result();
803     writer.append(OpCode::GetValueRangeTypes, result,
804                   getRangeStorageIndex(result), op.value());
805   } else {
806     writer.append(OpCode::GetValueType, op.result(), op.value());
807   }
808 }
809 
generate(pdl_interp::InferredTypesOp op,ByteCodeWriter & writer)810 void Generator::generate(pdl_interp::InferredTypesOp op,
811                          ByteCodeWriter &writer) {
812   // InferType maps to a null type as a marker for inferring result types.
813   getMemIndex(op.type()) = getMemIndex(Type());
814 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)815 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
816   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
817 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)818 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
819   ByteCodeField patternIndex = patterns.size();
820   patterns.emplace_back(PDLByteCodePattern::create(
821       op, rewriterToAddr[op.rewriter().getLeafReference()]));
822   writer.append(OpCode::RecordMatch, patternIndex,
823                 SuccessorRange(op.getOperation()), op.matchedOps());
824   writer.appendPDLValueList(op.inputs());
825 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)826 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
827   writer.append(OpCode::ReplaceOp, op.operation());
828   writer.appendPDLValueList(op.replValues());
829 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)830 void Generator::generate(pdl_interp::SwitchAttributeOp op,
831                          ByteCodeWriter &writer) {
832   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
833                 op.getSuccessors());
834 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)835 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
836                          ByteCodeWriter &writer) {
837   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
838                 op.getSuccessors());
839 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)840 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
841                          ByteCodeWriter &writer) {
842   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
843     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
844   });
845   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
846                 op.getSuccessors());
847 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)848 void Generator::generate(pdl_interp::SwitchResultCountOp op,
849                          ByteCodeWriter &writer) {
850   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
851                 op.getSuccessors());
852 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)853 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
854   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
855                 op.getSuccessors());
856 }
generate(pdl_interp::SwitchTypesOp op,ByteCodeWriter & writer)857 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
858   writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
859                 op.getSuccessors());
860 }
861 
862 //===----------------------------------------------------------------------===//
863 // PDLByteCode
864 //===----------------------------------------------------------------------===//
865 
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)866 PDLByteCode::PDLByteCode(ModuleOp module,
867                          llvm::StringMap<PDLConstraintFunction> constraintFns,
868                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
869   Generator generator(module.getContext(), uniquedData, matcherByteCode,
870                       rewriterByteCode, patterns, maxValueMemoryIndex,
871                       maxTypeRangeCount, maxValueRangeCount, constraintFns,
872                       rewriteFns);
873   generator.generate(module);
874 
875   // Initialize the external functions.
876   for (auto &it : constraintFns)
877     constraintFunctions.push_back(std::move(it.second));
878   for (auto &it : rewriteFns)
879     rewriteFunctions.push_back(std::move(it.second));
880 }
881 
882 /// Initialize the given state such that it can be used to execute the current
883 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const884 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
885   state.memory.resize(maxValueMemoryIndex, nullptr);
886   state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
887   state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
888   state.currentPatternBenefits.reserve(patterns.size());
889   for (const PDLByteCodePattern &pattern : patterns)
890     state.currentPatternBenefits.push_back(pattern.getBenefit());
891 }
892 
893 //===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
895 
896 namespace {
897 /// This class provides support for executing a bytecode stream.
898 class ByteCodeExecutor {
899 public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,MutableArrayRef<TypeRange> typeRangeMemory,std::vector<llvm::OwningArrayRef<Type>> & allocatedTypeRangeMemory,MutableArrayRef<ValueRange> valueRangeMemory,std::vector<llvm::OwningArrayRef<Value>> & allocatedValueRangeMemory,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)900   ByteCodeExecutor(
901       const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
902       MutableArrayRef<TypeRange> typeRangeMemory,
903       std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
904       MutableArrayRef<ValueRange> valueRangeMemory,
905       std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
906       ArrayRef<const void *> uniquedMemory, ArrayRef<ByteCodeField> code,
907       ArrayRef<PatternBenefit> currentPatternBenefits,
908       ArrayRef<PDLByteCodePattern> patterns,
909       ArrayRef<PDLConstraintFunction> constraintFunctions,
910       ArrayRef<PDLRewriteFunction> rewriteFunctions)
911       : curCodeIt(curCodeIt), memory(memory), typeRangeMemory(typeRangeMemory),
912         allocatedTypeRangeMemory(allocatedTypeRangeMemory),
913         valueRangeMemory(valueRangeMemory),
914         allocatedValueRangeMemory(allocatedValueRangeMemory),
915         uniquedMemory(uniquedMemory), code(code),
916         currentPatternBenefits(currentPatternBenefits), patterns(patterns),
917         constraintFunctions(constraintFunctions),
918         rewriteFunctions(rewriteFunctions) {}
919 
920   /// Start executing the code at the current bytecode index. `matches` is an
921   /// optional field provided when this function is executed in a matching
922   /// context.
923   void execute(PatternRewriter &rewriter,
924                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
925                Optional<Location> mainRewriteLoc = {});
926 
927 private:
928   /// Internal implementation of executing each of the bytecode commands.
929   void executeApplyConstraint(PatternRewriter &rewriter);
930   void executeApplyRewrite(PatternRewriter &rewriter);
931   void executeAreEqual();
932   void executeAreRangesEqual();
933   void executeBranch();
934   void executeCheckOperandCount();
935   void executeCheckOperationName();
936   void executeCheckResultCount();
937   void executeCheckTypes();
938   void executeCreateOperation(PatternRewriter &rewriter,
939                               Location mainRewriteLoc);
940   void executeCreateTypes();
941   void executeEraseOp(PatternRewriter &rewriter);
942   void executeGetAttribute();
943   void executeGetAttributeType();
944   void executeGetDefiningOp();
945   void executeGetOperand(unsigned index);
946   void executeGetOperands();
947   void executeGetResult(unsigned index);
948   void executeGetResults();
949   void executeGetValueType();
950   void executeGetValueRangeTypes();
951   void executeIsNotNull();
952   void executeRecordMatch(PatternRewriter &rewriter,
953                           SmallVectorImpl<PDLByteCode::MatchResult> &matches);
954   void executeReplaceOp(PatternRewriter &rewriter);
955   void executeSwitchAttribute();
956   void executeSwitchOperandCount();
957   void executeSwitchOperationName();
958   void executeSwitchResultCount();
959   void executeSwitchType();
960   void executeSwitchTypes();
961 
962   /// Read a value from the bytecode buffer, optionally skipping a certain
963   /// number of prefix values. These methods always update the buffer to point
964   /// to the next field after the read data.
965   template <typename T = ByteCodeField>
read(size_t skipN=0)966   T read(size_t skipN = 0) {
967     curCodeIt += skipN;
968     return readImpl<T>();
969   }
read(size_t skipN=0)970   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
971 
972   /// Read a list of values from the bytecode buffer.
973   template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)974   void readList(SmallVectorImpl<T> &list) {
975     list.clear();
976     for (unsigned i = 0, e = read(); i != e; ++i)
977       list.push_back(read<ValueT>());
978   }
979 
980   /// Read a list of values from the bytecode buffer. The values may be encoded
981   /// as either Value or ValueRange elements.
readValueList(SmallVectorImpl<Value> & list)982   void readValueList(SmallVectorImpl<Value> &list) {
983     for (unsigned i = 0, e = read(); i != e; ++i) {
984       if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
985         list.push_back(read<Value>());
986       } else {
987         ValueRange *values = read<ValueRange *>();
988         list.append(values->begin(), values->end());
989       }
990     }
991   }
992 
993   /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)994   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
995   /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)996   void selectJump(size_t destIndex) {
997     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
998   }
999 
1000   /// Handle a switch operation with the provided value and cases.
1001   template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
handleSwitch(const T & value,RangeT && cases,Comparator cmp={})1002   void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1003     LLVM_DEBUG({
1004       llvm::dbgs() << "  * Value: " << value << "\n"
1005                    << "  * Cases: ";
1006       llvm::interleaveComma(cases, llvm::dbgs());
1007       llvm::dbgs() << "\n";
1008     });
1009 
1010     // Check to see if the attribute value is within the case list. Jump to
1011     // the correct successor index based on the result.
1012     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1013       if (cmp(*it, value))
1014         return selectJump(size_t((it - cases.begin()) + 1));
1015     selectJump(size_t(0));
1016   }
1017 
1018   /// Internal implementation of reading various data types from the bytecode
1019   /// stream.
1020   template <typename T>
readFromMemory()1021   const void *readFromMemory() {
1022     size_t index = *curCodeIt++;
1023 
1024     // If this type is an SSA value, it can only be stored in non-const memory.
1025     if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1026                         Value>::value ||
1027         index < memory.size())
1028       return memory[index];
1029 
1030     // Otherwise, if this index is not inbounds it is uniqued.
1031     return uniquedMemory[index - memory.size()];
1032   }
1033   template <typename T>
readImpl()1034   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1035     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1036   }
1037   template <typename T>
1038   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1039                    T>
readImpl()1040   readImpl() {
1041     return T(T::getFromOpaquePointer(readFromMemory<T>()));
1042   }
1043   template <typename T>
readImpl()1044   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1045     switch (read<PDLValue::Kind>()) {
1046     case PDLValue::Kind::Attribute:
1047       return read<Attribute>();
1048     case PDLValue::Kind::Operation:
1049       return read<Operation *>();
1050     case PDLValue::Kind::Type:
1051       return read<Type>();
1052     case PDLValue::Kind::Value:
1053       return read<Value>();
1054     case PDLValue::Kind::TypeRange:
1055       return read<TypeRange *>();
1056     case PDLValue::Kind::ValueRange:
1057       return read<ValueRange *>();
1058     }
1059     llvm_unreachable("unhandled PDLValue::Kind");
1060   }
1061   template <typename T>
readImpl()1062   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1063     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1064                   "unexpected ByteCode address size");
1065     ByteCodeAddr result;
1066     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1067     curCodeIt += 2;
1068     return result;
1069   }
1070   template <typename T>
readImpl()1071   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1072     return *curCodeIt++;
1073   }
1074   template <typename T>
readImpl()1075   std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1076     return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1077   }
1078 
1079   /// The underlying bytecode buffer.
1080   const ByteCodeField *curCodeIt;
1081 
1082   /// The current execution memory.
1083   MutableArrayRef<const void *> memory;
1084   MutableArrayRef<TypeRange> typeRangeMemory;
1085   std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1086   MutableArrayRef<ValueRange> valueRangeMemory;
1087   std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1088 
1089   /// References to ByteCode data necessary for execution.
1090   ArrayRef<const void *> uniquedMemory;
1091   ArrayRef<ByteCodeField> code;
1092   ArrayRef<PatternBenefit> currentPatternBenefits;
1093   ArrayRef<PDLByteCodePattern> patterns;
1094   ArrayRef<PDLConstraintFunction> constraintFunctions;
1095   ArrayRef<PDLRewriteFunction> rewriteFunctions;
1096 };
1097 
1098 /// This class is an instantiation of the PDLResultList that provides access to
1099 /// the returned results. This API is not on `PDLResultList` to avoid
1100 /// overexposing access to information specific solely to the ByteCode.
1101 class ByteCodeRewriteResultList : public PDLResultList {
1102 public:
ByteCodeRewriteResultList(unsigned maxNumResults)1103   ByteCodeRewriteResultList(unsigned maxNumResults)
1104       : PDLResultList(maxNumResults) {}
1105 
1106   /// Return the list of PDL results.
getResults()1107   MutableArrayRef<PDLValue> getResults() { return results; }
1108 
1109   /// Return the type ranges allocated by this list.
getAllocatedTypeRanges()1110   MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1111     return allocatedTypeRanges;
1112   }
1113 
1114   /// Return the value ranges allocated by this list.
getAllocatedValueRanges()1115   MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1116     return allocatedValueRanges;
1117   }
1118 };
1119 } // end anonymous namespace
1120 
executeApplyConstraint(PatternRewriter & rewriter)1121 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1122   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1123   const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1124   ArrayAttr constParams = read<ArrayAttr>();
1125   SmallVector<PDLValue, 16> args;
1126   readList<PDLValue>(args);
1127 
1128   LLVM_DEBUG({
1129     llvm::dbgs() << "  * Arguments: ";
1130     llvm::interleaveComma(args, llvm::dbgs());
1131     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1132   });
1133 
1134   // Invoke the constraint and jump to the proper destination.
1135   selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1136 }
1137 
executeApplyRewrite(PatternRewriter & rewriter)1138 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1139   LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1140   const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1141   ArrayAttr constParams = read<ArrayAttr>();
1142   SmallVector<PDLValue, 16> args;
1143   readList<PDLValue>(args);
1144 
1145   LLVM_DEBUG({
1146     llvm::dbgs() << "  * Arguments: ";
1147     llvm::interleaveComma(args, llvm::dbgs());
1148     llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
1149   });
1150 
1151   // Execute the rewrite function.
1152   ByteCodeField numResults = read();
1153   ByteCodeRewriteResultList results(numResults);
1154   rewriteFn(args, constParams, rewriter, results);
1155 
1156   assert(results.getResults().size() == numResults &&
1157          "native PDL rewrite function returned unexpected number of results");
1158 
1159   // Store the results in the bytecode memory.
1160   for (PDLValue &result : results.getResults()) {
1161     LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n");
1162 
1163 // In debug mode we also verify the expected kind of the result.
1164 #ifndef NDEBUG
1165     assert(result.getKind() == read<PDLValue::Kind>() &&
1166            "native PDL rewrite function returned an unexpected type of result");
1167 #endif
1168 
1169     // If the result is a range, we need to copy it over to the bytecodes
1170     // range memory.
1171     if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1172       unsigned rangeIndex = read();
1173       typeRangeMemory[rangeIndex] = *typeRange;
1174       memory[read()] = &typeRangeMemory[rangeIndex];
1175     } else if (Optional<ValueRange> valueRange =
1176                    result.dyn_cast<ValueRange>()) {
1177       unsigned rangeIndex = read();
1178       valueRangeMemory[rangeIndex] = *valueRange;
1179       memory[read()] = &valueRangeMemory[rangeIndex];
1180     } else {
1181       memory[read()] = result.getAsOpaquePointer();
1182     }
1183   }
1184 
1185   // Copy over any underlying storage allocated for result ranges.
1186   for (auto &it : results.getAllocatedTypeRanges())
1187     allocatedTypeRangeMemory.push_back(std::move(it));
1188   for (auto &it : results.getAllocatedValueRanges())
1189     allocatedValueRangeMemory.push_back(std::move(it));
1190 }
1191 
executeAreEqual()1192 void ByteCodeExecutor::executeAreEqual() {
1193   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1194   const void *lhs = read<const void *>();
1195   const void *rhs = read<const void *>();
1196 
1197   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n");
1198   selectJump(lhs == rhs);
1199 }
1200 
executeAreRangesEqual()1201 void ByteCodeExecutor::executeAreRangesEqual() {
1202   LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1203   PDLValue::Kind valueKind = read<PDLValue::Kind>();
1204   const void *lhs = read<const void *>();
1205   const void *rhs = read<const void *>();
1206 
1207   switch (valueKind) {
1208   case PDLValue::Kind::TypeRange: {
1209     const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1210     const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1211     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1212     selectJump(*lhsRange == *rhsRange);
1213     break;
1214   }
1215   case PDLValue::Kind::ValueRange: {
1216     const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1217     const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1218     LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1219     selectJump(*lhsRange == *rhsRange);
1220     break;
1221   }
1222   default:
1223     llvm_unreachable("unexpected `AreRangesEqual` value kind");
1224   }
1225 }
1226 
executeBranch()1227 void ByteCodeExecutor::executeBranch() {
1228   LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1229   curCodeIt = &code[read<ByteCodeAddr>()];
1230 }
1231 
executeCheckOperandCount()1232 void ByteCodeExecutor::executeCheckOperandCount() {
1233   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1234   Operation *op = read<Operation *>();
1235   uint32_t expectedCount = read<uint32_t>();
1236   bool compareAtLeast = read();
1237 
1238   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
1239                           << "  * Expected: " << expectedCount << "\n"
1240                           << "  * Comparator: "
1241                           << (compareAtLeast ? ">=" : "==") << "\n");
1242   if (compareAtLeast)
1243     selectJump(op->getNumOperands() >= expectedCount);
1244   else
1245     selectJump(op->getNumOperands() == expectedCount);
1246 }
1247 
executeCheckOperationName()1248 void ByteCodeExecutor::executeCheckOperationName() {
1249   LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1250   Operation *op = read<Operation *>();
1251   OperationName expectedName = read<OperationName>();
1252 
1253   LLVM_DEBUG(llvm::dbgs() << "  * Found: \"" << op->getName() << "\"\n"
1254                           << "  * Expected: \"" << expectedName << "\"\n");
1255   selectJump(op->getName() == expectedName);
1256 }
1257 
executeCheckResultCount()1258 void ByteCodeExecutor::executeCheckResultCount() {
1259   LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1260   Operation *op = read<Operation *>();
1261   uint32_t expectedCount = read<uint32_t>();
1262   bool compareAtLeast = read();
1263 
1264   LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
1265                           << "  * Expected: " << expectedCount << "\n"
1266                           << "  * Comparator: "
1267                           << (compareAtLeast ? ">=" : "==") << "\n");
1268   if (compareAtLeast)
1269     selectJump(op->getNumResults() >= expectedCount);
1270   else
1271     selectJump(op->getNumResults() == expectedCount);
1272 }
1273 
executeCheckTypes()1274 void ByteCodeExecutor::executeCheckTypes() {
1275   LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1276   TypeRange *lhs = read<TypeRange *>();
1277   Attribute rhs = read<Attribute>();
1278   LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
1279 
1280   selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1281 }
1282 
executeCreateTypes()1283 void ByteCodeExecutor::executeCreateTypes() {
1284   LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1285   unsigned memIndex = read();
1286   unsigned rangeIndex = read();
1287   ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1288 
1289   LLVM_DEBUG(llvm::dbgs() << "  * Types: " << typesAttr << "\n\n");
1290 
1291   // Allocate a buffer for this type range.
1292   llvm::OwningArrayRef<Type> storage(typesAttr.size());
1293   llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1294   allocatedTypeRangeMemory.emplace_back(std::move(storage));
1295 
1296   // Assign this to the range slot and use the range as the value for the
1297   // memory index.
1298   typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1299   memory[memIndex] = &typeRangeMemory[rangeIndex];
1300 }
1301 
executeCreateOperation(PatternRewriter & rewriter,Location mainRewriteLoc)1302 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1303                                               Location mainRewriteLoc) {
1304   LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1305 
1306   unsigned memIndex = read();
1307   OperationState state(mainRewriteLoc, read<OperationName>());
1308   readValueList(state.operands);
1309   for (unsigned i = 0, e = read(); i != e; ++i) {
1310     Identifier name = read<Identifier>();
1311     if (Attribute attr = read<Attribute>())
1312       state.addAttribute(name, attr);
1313   }
1314 
1315   for (unsigned i = 0, e = read(); i != e; ++i) {
1316     if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1317       state.types.push_back(read<Type>());
1318       continue;
1319     }
1320 
1321     // If we find a null range, this signals that the types are infered.
1322     if (TypeRange *resultTypes = read<TypeRange *>()) {
1323       state.types.append(resultTypes->begin(), resultTypes->end());
1324       continue;
1325     }
1326 
1327     // Handle the case where the operation has inferred types.
1328     InferTypeOpInterface::Concept *concept =
1329         state.name.getAbstractOperation()->getInterface<InferTypeOpInterface>();
1330 
1331     // TODO: Handle failure.
1332     state.types.clear();
1333     if (failed(concept->inferReturnTypes(
1334             state.getContext(), state.location, state.operands,
1335             state.attributes.getDictionary(state.getContext()), state.regions,
1336             state.types)))
1337       return;
1338     break;
1339   }
1340 
1341   Operation *resultOp = rewriter.createOperation(state);
1342   memory[memIndex] = resultOp;
1343 
1344   LLVM_DEBUG({
1345     llvm::dbgs() << "  * Attributes: "
1346                  << state.attributes.getDictionary(state.getContext())
1347                  << "\n  * Operands: ";
1348     llvm::interleaveComma(state.operands, llvm::dbgs());
1349     llvm::dbgs() << "\n  * Result Types: ";
1350     llvm::interleaveComma(state.types, llvm::dbgs());
1351     llvm::dbgs() << "\n  * Result: " << *resultOp << "\n";
1352   });
1353 }
1354 
executeEraseOp(PatternRewriter & rewriter)1355 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1356   LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1357   Operation *op = read<Operation *>();
1358 
1359   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1360   rewriter.eraseOp(op);
1361 }
1362 
executeGetAttribute()1363 void ByteCodeExecutor::executeGetAttribute() {
1364   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1365   unsigned memIndex = read();
1366   Operation *op = read<Operation *>();
1367   Identifier attrName = read<Identifier>();
1368   Attribute attr = op->getAttr(attrName);
1369 
1370   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1371                           << "  * Attribute: " << attrName << "\n"
1372                           << "  * Result: " << attr << "\n");
1373   memory[memIndex] = attr.getAsOpaquePointer();
1374 }
1375 
executeGetAttributeType()1376 void ByteCodeExecutor::executeGetAttributeType() {
1377   LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1378   unsigned memIndex = read();
1379   Attribute attr = read<Attribute>();
1380   Type type = attr ? attr.getType() : Type();
1381 
1382   LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1383                           << "  * Result: " << type << "\n");
1384   memory[memIndex] = type.getAsOpaquePointer();
1385 }
1386 
executeGetDefiningOp()1387 void ByteCodeExecutor::executeGetDefiningOp() {
1388   LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1389   unsigned memIndex = read();
1390   Operation *op = nullptr;
1391   if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1392     Value value = read<Value>();
1393     if (value)
1394       op = value.getDefiningOp();
1395     LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1396   } else {
1397     ValueRange *values = read<ValueRange *>();
1398     if (values && !values->empty()) {
1399       op = values->front().getDefiningOp();
1400     }
1401     LLVM_DEBUG(llvm::dbgs() << "  * Values: " << values << "\n");
1402   }
1403 
1404   LLVM_DEBUG(llvm::dbgs() << "  * Result: " << op << "\n");
1405   memory[memIndex] = op;
1406 }
1407 
executeGetOperand(unsigned index)1408 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1409   Operation *op = read<Operation *>();
1410   unsigned memIndex = read();
1411   Value operand =
1412       index < op->getNumOperands() ? op->getOperand(index) : Value();
1413 
1414   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1415                           << "  * Index: " << index << "\n"
1416                           << "  * Result: " << operand << "\n");
1417   memory[memIndex] = operand.getAsOpaquePointer();
1418 }
1419 
1420 /// This function is the internal implementation of `GetResults` and
1421 /// `GetOperands` that provides support for extracting a value range from the
1422 /// given operation.
1423 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1424 static void *
executeGetOperandsResults(RangeT values,Operation * op,unsigned index,ByteCodeField rangeIndex,StringRef attrSizedSegments,MutableArrayRef<ValueRange> & valueRangeMemory)1425 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1426                           ByteCodeField rangeIndex, StringRef attrSizedSegments,
1427                           MutableArrayRef<ValueRange> &valueRangeMemory) {
1428   // Check for the sentinel index that signals that all values should be
1429   // returned.
1430   if (index == std::numeric_limits<uint32_t>::max()) {
1431     LLVM_DEBUG(llvm::dbgs() << "  * Getting all values\n");
1432     // `values` is already the full value range.
1433 
1434     // Otherwise, check to see if this operation uses AttrSizedSegments.
1435   } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1436     LLVM_DEBUG(llvm::dbgs()
1437                << "  * Extracting values from `" << attrSizedSegments << "`\n");
1438 
1439     auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1440     if (!segmentAttr || segmentAttr.getNumElements() <= index)
1441       return nullptr;
1442 
1443     auto segments = segmentAttr.getValues<int32_t>();
1444     unsigned startIndex =
1445         std::accumulate(segments.begin(), segments.begin() + index, 0);
1446     values = values.slice(startIndex, *std::next(segments.begin(), index));
1447 
1448     LLVM_DEBUG(llvm::dbgs() << "  * Extracting range[" << startIndex << ", "
1449                             << *std::next(segments.begin(), index) << "]\n");
1450 
1451     // Otherwise, assume this is the last operand group of the operation.
1452     // FIXME: We currently don't support operations with
1453     // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1454     // have a way to detect it's presence.
1455   } else if (values.size() >= index) {
1456     LLVM_DEBUG(llvm::dbgs()
1457                << "  * Treating values as trailing variadic range\n");
1458     values = values.drop_front(index);
1459 
1460     // If we couldn't detect a way to compute the values, bail out.
1461   } else {
1462     return nullptr;
1463   }
1464 
1465   // If the range index is valid, we are returning a range.
1466   if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1467     valueRangeMemory[rangeIndex] = values;
1468     return &valueRangeMemory[rangeIndex];
1469   }
1470 
1471   // If a range index wasn't provided, the range is required to be non-variadic.
1472   return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1473 }
1474 
executeGetOperands()1475 void ByteCodeExecutor::executeGetOperands() {
1476   LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1477   unsigned index = read<uint32_t>();
1478   Operation *op = read<Operation *>();
1479   ByteCodeField rangeIndex = read();
1480 
1481   void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1482       op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1483       valueRangeMemory);
1484   if (!result)
1485     LLVM_DEBUG(llvm::dbgs() << "  * Invalid operand range\n");
1486   memory[read()] = result;
1487 }
1488 
executeGetResult(unsigned index)1489 void ByteCodeExecutor::executeGetResult(unsigned index) {
1490   Operation *op = read<Operation *>();
1491   unsigned memIndex = read();
1492   OpResult result =
1493       index < op->getNumResults() ? op->getResult(index) : OpResult();
1494 
1495   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1496                           << "  * Index: " << index << "\n"
1497                           << "  * Result: " << result << "\n");
1498   memory[memIndex] = result.getAsOpaquePointer();
1499 }
1500 
executeGetResults()1501 void ByteCodeExecutor::executeGetResults() {
1502   LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1503   unsigned index = read<uint32_t>();
1504   Operation *op = read<Operation *>();
1505   ByteCodeField rangeIndex = read();
1506 
1507   void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1508       op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1509       valueRangeMemory);
1510   if (!result)
1511     LLVM_DEBUG(llvm::dbgs() << "  * Invalid result range\n");
1512   memory[read()] = result;
1513 }
1514 
executeGetValueType()1515 void ByteCodeExecutor::executeGetValueType() {
1516   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1517   unsigned memIndex = read();
1518   Value value = read<Value>();
1519   Type type = value ? value.getType() : Type();
1520 
1521   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1522                           << "  * Result: " << type << "\n");
1523   memory[memIndex] = type.getAsOpaquePointer();
1524 }
1525 
executeGetValueRangeTypes()1526 void ByteCodeExecutor::executeGetValueRangeTypes() {
1527   LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1528   unsigned memIndex = read();
1529   unsigned rangeIndex = read();
1530   ValueRange *values = read<ValueRange *>();
1531   if (!values) {
1532     LLVM_DEBUG(llvm::dbgs() << "  * Values: <NULL>\n\n");
1533     memory[memIndex] = nullptr;
1534     return;
1535   }
1536 
1537   LLVM_DEBUG({
1538     llvm::dbgs() << "  * Values (" << values->size() << "): ";
1539     llvm::interleaveComma(*values, llvm::dbgs());
1540     llvm::dbgs() << "\n  * Result: ";
1541     llvm::interleaveComma(values->getType(), llvm::dbgs());
1542     llvm::dbgs() << "\n";
1543   });
1544   typeRangeMemory[rangeIndex] = values->getType();
1545   memory[memIndex] = &typeRangeMemory[rangeIndex];
1546 }
1547 
executeIsNotNull()1548 void ByteCodeExecutor::executeIsNotNull() {
1549   LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1550   const void *value = read<const void *>();
1551 
1552   LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n");
1553   selectJump(value != nullptr);
1554 }
1555 
executeRecordMatch(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> & matches)1556 void ByteCodeExecutor::executeRecordMatch(
1557     PatternRewriter &rewriter,
1558     SmallVectorImpl<PDLByteCode::MatchResult> &matches) {
1559   LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1560   unsigned patternIndex = read();
1561   PatternBenefit benefit = currentPatternBenefits[patternIndex];
1562   const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1563 
1564   // If the benefit of the pattern is impossible, skip the processing of the
1565   // rest of the pattern.
1566   if (benefit.isImpossibleToMatch()) {
1567     LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n");
1568     curCodeIt = dest;
1569     return;
1570   }
1571 
1572   // Create a fused location containing the locations of each of the
1573   // operations used in the match. This will be used as the location for
1574   // created operations during the rewrite that don't already have an
1575   // explicit location set.
1576   unsigned numMatchLocs = read();
1577   SmallVector<Location, 4> matchLocs;
1578   matchLocs.reserve(numMatchLocs);
1579   for (unsigned i = 0; i != numMatchLocs; ++i)
1580     matchLocs.push_back(read<Operation *>()->getLoc());
1581   Location matchLoc = rewriter.getFusedLoc(matchLocs);
1582 
1583   LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1584                           << "  * Location: " << matchLoc << "\n");
1585   matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1586   PDLByteCode::MatchResult &match = matches.back();
1587 
1588   // Record all of the inputs to the match. If any of the inputs are ranges, we
1589   // will also need to remap the range pointer to memory stored in the match
1590   // state.
1591   unsigned numInputs = read();
1592   match.values.reserve(numInputs);
1593   match.typeRangeValues.reserve(numInputs);
1594   match.valueRangeValues.reserve(numInputs);
1595   for (unsigned i = 0; i < numInputs; ++i) {
1596     switch (read<PDLValue::Kind>()) {
1597     case PDLValue::Kind::TypeRange:
1598       match.typeRangeValues.push_back(*read<TypeRange *>());
1599       match.values.push_back(&match.typeRangeValues.back());
1600       break;
1601     case PDLValue::Kind::ValueRange:
1602       match.valueRangeValues.push_back(*read<ValueRange *>());
1603       match.values.push_back(&match.valueRangeValues.back());
1604       break;
1605     default:
1606       match.values.push_back(read<const void *>());
1607       break;
1608     }
1609   }
1610   curCodeIt = dest;
1611 }
1612 
executeReplaceOp(PatternRewriter & rewriter)1613 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1614   LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1615   Operation *op = read<Operation *>();
1616   SmallVector<Value, 16> args;
1617   readValueList(args);
1618 
1619   LLVM_DEBUG({
1620     llvm::dbgs() << "  * Operation: " << *op << "\n"
1621                  << "  * Values: ";
1622     llvm::interleaveComma(args, llvm::dbgs());
1623     llvm::dbgs() << "\n";
1624   });
1625   rewriter.replaceOp(op, args);
1626 }
1627 
executeSwitchAttribute()1628 void ByteCodeExecutor::executeSwitchAttribute() {
1629   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1630   Attribute value = read<Attribute>();
1631   ArrayAttr cases = read<ArrayAttr>();
1632   handleSwitch(value, cases);
1633 }
1634 
executeSwitchOperandCount()1635 void ByteCodeExecutor::executeSwitchOperandCount() {
1636   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1637   Operation *op = read<Operation *>();
1638   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1639 
1640   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1641   handleSwitch(op->getNumOperands(), cases);
1642 }
1643 
executeSwitchOperationName()1644 void ByteCodeExecutor::executeSwitchOperationName() {
1645   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1646   OperationName value = read<Operation *>()->getName();
1647   size_t caseCount = read();
1648 
1649   // The operation names are stored in-line, so to print them out for
1650   // debugging purposes we need to read the array before executing the
1651   // switch so that we can display all of the possible values.
1652   LLVM_DEBUG({
1653     const ByteCodeField *prevCodeIt = curCodeIt;
1654     llvm::dbgs() << "  * Value: " << value << "\n"
1655                  << "  * Cases: ";
1656     llvm::interleaveComma(
1657         llvm::map_range(llvm::seq<size_t>(0, caseCount),
1658                         [&](size_t) { return read<OperationName>(); }),
1659         llvm::dbgs());
1660     llvm::dbgs() << "\n";
1661     curCodeIt = prevCodeIt;
1662   });
1663 
1664   // Try to find the switch value within any of the cases.
1665   for (size_t i = 0; i != caseCount; ++i) {
1666     if (read<OperationName>() == value) {
1667       curCodeIt += (caseCount - i - 1);
1668       return selectJump(i + 1);
1669     }
1670   }
1671   selectJump(size_t(0));
1672 }
1673 
executeSwitchResultCount()1674 void ByteCodeExecutor::executeSwitchResultCount() {
1675   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1676   Operation *op = read<Operation *>();
1677   auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1678 
1679   LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1680   handleSwitch(op->getNumResults(), cases);
1681 }
1682 
executeSwitchType()1683 void ByteCodeExecutor::executeSwitchType() {
1684   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1685   Type value = read<Type>();
1686   auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1687   handleSwitch(value, cases);
1688 }
1689 
executeSwitchTypes()1690 void ByteCodeExecutor::executeSwitchTypes() {
1691   LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
1692   TypeRange *value = read<TypeRange *>();
1693   auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
1694   if (!value) {
1695     LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
1696     return selectJump(size_t(0));
1697   }
1698   handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
1699     return value == caseValue.getAsValueRange<TypeAttr>();
1700   });
1701 }
1702 
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)1703 void ByteCodeExecutor::execute(
1704     PatternRewriter &rewriter,
1705     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
1706     Optional<Location> mainRewriteLoc) {
1707   while (true) {
1708     OpCode opCode = static_cast<OpCode>(read());
1709     switch (opCode) {
1710     case ApplyConstraint:
1711       executeApplyConstraint(rewriter);
1712       break;
1713     case ApplyRewrite:
1714       executeApplyRewrite(rewriter);
1715       break;
1716     case AreEqual:
1717       executeAreEqual();
1718       break;
1719     case AreRangesEqual:
1720       executeAreRangesEqual();
1721       break;
1722     case Branch:
1723       executeBranch();
1724       break;
1725     case CheckOperandCount:
1726       executeCheckOperandCount();
1727       break;
1728     case CheckOperationName:
1729       executeCheckOperationName();
1730       break;
1731     case CheckResultCount:
1732       executeCheckResultCount();
1733       break;
1734     case CheckTypes:
1735       executeCheckTypes();
1736       break;
1737     case CreateOperation:
1738       executeCreateOperation(rewriter, *mainRewriteLoc);
1739       break;
1740     case CreateTypes:
1741       executeCreateTypes();
1742       break;
1743     case EraseOp:
1744       executeEraseOp(rewriter);
1745       break;
1746     case Finalize:
1747       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1748       return;
1749     case GetAttribute:
1750       executeGetAttribute();
1751       break;
1752     case GetAttributeType:
1753       executeGetAttributeType();
1754       break;
1755     case GetDefiningOp:
1756       executeGetDefiningOp();
1757       break;
1758     case GetOperand0:
1759     case GetOperand1:
1760     case GetOperand2:
1761     case GetOperand3: {
1762       unsigned index = opCode - GetOperand0;
1763       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
1764       executeGetOperand(index);
1765       break;
1766     }
1767     case GetOperandN:
1768       LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
1769       executeGetOperand(read<uint32_t>());
1770       break;
1771     case GetOperands:
1772       executeGetOperands();
1773       break;
1774     case GetResult0:
1775     case GetResult1:
1776     case GetResult2:
1777     case GetResult3: {
1778       unsigned index = opCode - GetResult0;
1779       LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
1780       executeGetResult(index);
1781       break;
1782     }
1783     case GetResultN:
1784       LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
1785       executeGetResult(read<uint32_t>());
1786       break;
1787     case GetResults:
1788       executeGetResults();
1789       break;
1790     case GetValueType:
1791       executeGetValueType();
1792       break;
1793     case GetValueRangeTypes:
1794       executeGetValueRangeTypes();
1795       break;
1796     case IsNotNull:
1797       executeIsNotNull();
1798       break;
1799     case RecordMatch:
1800       assert(matches &&
1801              "expected matches to be provided when executing the matcher");
1802       executeRecordMatch(rewriter, *matches);
1803       break;
1804     case ReplaceOp:
1805       executeReplaceOp(rewriter);
1806       break;
1807     case SwitchAttribute:
1808       executeSwitchAttribute();
1809       break;
1810     case SwitchOperandCount:
1811       executeSwitchOperandCount();
1812       break;
1813     case SwitchOperationName:
1814       executeSwitchOperationName();
1815       break;
1816     case SwitchResultCount:
1817       executeSwitchResultCount();
1818       break;
1819     case SwitchType:
1820       executeSwitchType();
1821       break;
1822     case SwitchTypes:
1823       executeSwitchTypes();
1824       break;
1825     }
1826     LLVM_DEBUG(llvm::dbgs() << "\n");
1827   }
1828 }
1829 
1830 /// Run the pattern matcher on the given root operation, collecting the matched
1831 /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const1832 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1833                         SmallVectorImpl<MatchResult> &matches,
1834                         PDLByteCodeMutableState &state) const {
1835   // The first memory slot is always the root operation.
1836   state.memory[0] = op;
1837 
1838   // The matcher function always starts at code address 0.
1839   ByteCodeExecutor executor(
1840       matcherByteCode.data(), state.memory, state.typeRangeMemory,
1841       state.allocatedTypeRangeMemory, state.valueRangeMemory,
1842       state.allocatedValueRangeMemory, uniquedData, matcherByteCode,
1843       state.currentPatternBenefits, patterns, constraintFunctions,
1844       rewriteFunctions);
1845   executor.execute(rewriter, &matches);
1846 
1847   // Order the found matches by benefit.
1848   std::stable_sort(matches.begin(), matches.end(),
1849                    [](const MatchResult &lhs, const MatchResult &rhs) {
1850                      return lhs.benefit > rhs.benefit;
1851                    });
1852 }
1853 
1854 /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const1855 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1856                           PDLByteCodeMutableState &state) const {
1857   // The arguments of the rewrite function are stored at the start of the
1858   // memory buffer.
1859   llvm::copy(match.values, state.memory.begin());
1860 
1861   ByteCodeExecutor executor(
1862       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1863       state.typeRangeMemory, state.allocatedTypeRangeMemory,
1864       state.valueRangeMemory, state.allocatedValueRangeMemory, uniquedData,
1865       rewriterByteCode, state.currentPatternBenefits, patterns,
1866       constraintFunctions, rewriteFunctions);
1867   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1868 }
1869