1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "pdl-bytecode"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // PDLByteCodePattern
31 //===----------------------------------------------------------------------===//
32 
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34                                               ByteCodeAddr rewriterAddr) {
35   SmallVector<StringRef, 8> generatedOps;
36   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37     generatedOps =
38         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39 
40   PatternBenefit benefit = matchOp.benefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Check to see if this is pattern matches a specific operation type.
44   if (Optional<StringRef> rootKind = matchOp.rootKind())
45     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46                               ctx);
47   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48                             MatchAnyOpTypeTag());
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // PDLByteCodeMutableState
53 //===----------------------------------------------------------------------===//
54 
55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56 /// to the position of the pattern within the range returned by
57 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59                                                    PatternBenefit benefit) {
60   currentPatternBenefits[patternIndex] = benefit;
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Bytecode OpCodes
65 //===----------------------------------------------------------------------===//
66 
67 namespace {
68 enum OpCode : ByteCodeField {
69   /// Apply an externally registered constraint.
70   ApplyConstraint,
71   /// Apply an externally registered rewrite.
72   ApplyRewrite,
73   /// Check if two generic values are equal.
74   AreEqual,
75   /// Unconditional branch.
76   Branch,
77   /// Compare the operand count of an operation with a constant.
78   CheckOperandCount,
79   /// Compare the name of an operation with a constant.
80   CheckOperationName,
81   /// Compare the result count of an operation with a constant.
82   CheckResultCount,
83   /// Invoke a native creation method.
84   CreateNative,
85   /// Create an operation.
86   CreateOperation,
87   /// Erase an operation.
88   EraseOp,
89   /// Terminate a matcher or rewrite sequence.
90   Finalize,
91   /// Get a specific attribute of an operation.
92   GetAttribute,
93   /// Get the type of an attribute.
94   GetAttributeType,
95   /// Get the defining operation of a value.
96   GetDefiningOp,
97   /// Get a specific operand of an operation.
98   GetOperand0,
99   GetOperand1,
100   GetOperand2,
101   GetOperand3,
102   GetOperandN,
103   /// Get a specific result of an operation.
104   GetResult0,
105   GetResult1,
106   GetResult2,
107   GetResult3,
108   GetResultN,
109   /// Get the type of a value.
110   GetValueType,
111   /// Check if a generic value is not null.
112   IsNotNull,
113   /// Record a successful pattern match.
114   RecordMatch,
115   /// Replace an operation.
116   ReplaceOp,
117   /// Compare an attribute with a set of constants.
118   SwitchAttribute,
119   /// Compare the operand count of an operation with a set of constants.
120   SwitchOperandCount,
121   /// Compare the name of an operation with a set of constants.
122   SwitchOperationName,
123   /// Compare the result count of an operation with a set of constants.
124   SwitchResultCount,
125   /// Compare a type with a set of constants.
126   SwitchType,
127 };
128 
/// The kind of a generic PDL value. The writer encodes a kind into the
/// bytecode as its numeric value, so the enumerator order is significant.
enum class PDLValueKind {
  Attribute,
  Operation,
  Type,
  Value,
};
130 } // end anonymous namespace
131 
132 //===----------------------------------------------------------------------===//
133 // ByteCode Generation
134 //===----------------------------------------------------------------------===//
135 
136 //===----------------------------------------------------------------------===//
137 // Generator
138 
139 namespace {
140 struct ByteCodeWriter;
141 
142 /// This class represents the main generator for the pattern bytecode.
143 class Generator {
144 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLCreateFunction> & createFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)145   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
146             SmallVectorImpl<ByteCodeField> &matcherByteCode,
147             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
148             SmallVectorImpl<PDLByteCodePattern> &patterns,
149             ByteCodeField &maxValueMemoryIndex,
150             llvm::StringMap<PDLConstraintFunction> &constraintFns,
151             llvm::StringMap<PDLCreateFunction> &createFns,
152             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
153       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
154         rewriterByteCode(rewriterByteCode), patterns(patterns),
155         maxValueMemoryIndex(maxValueMemoryIndex) {
156     for (auto it : llvm::enumerate(constraintFns))
157       constraintToMemIndex.try_emplace(it.value().first(), it.index());
158     for (auto it : llvm::enumerate(createFns))
159       nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
160     for (auto it : llvm::enumerate(rewriteFns))
161       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
162   }
163 
164   /// Generate the bytecode for the given PDL interpreter module.
165   void generate(ModuleOp module);
166 
167   /// Return the memory index to use for the given value.
getMemIndex(Value value)168   ByteCodeField &getMemIndex(Value value) {
169     assert(valueToMemIndex.count(value) &&
170            "expected memory index to be assigned");
171     return valueToMemIndex[value];
172   }
173 
174   /// Return an index to use when referring to the given data that is uniqued in
175   /// the MLIR context.
176   template <typename T>
177   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)178   getMemIndex(T val) {
179     const void *opaqueVal = val.getAsOpaquePointer();
180 
181     // Get or insert a reference to this value.
182     auto it = uniquedDataToMemIndex.try_emplace(
183         opaqueVal, maxValueMemoryIndex + uniquedData.size());
184     if (it.second)
185       uniquedData.push_back(opaqueVal);
186     return it.first->second;
187   }
188 
189 private:
190   /// Allocate memory indices for the results of operations within the matcher
191   /// and rewriters.
192   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
193 
194   /// Generate the bytecode for the given operation.
195   void generate(Operation *op, ByteCodeWriter &writer);
196   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
197   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
198   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
199   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
200   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
201   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
202   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
203   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
204   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
205   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
206   void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
207   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
208   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
209   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
210   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
211   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
212   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
213   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
214   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
215   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
216   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
217   void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
218   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
219   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
220   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
221   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
222   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
223   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
224   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
225   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
226 
227   /// Mapping from value to its corresponding memory index.
228   DenseMap<Value, ByteCodeField> valueToMemIndex;
229 
230   /// Mapping from the name of an externally registered rewrite to its index in
231   /// the bytecode registry.
232   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
233 
234   /// Mapping from the name of an externally registered constraint to its index
235   /// in the bytecode registry.
236   llvm::StringMap<ByteCodeField> constraintToMemIndex;
237 
238   /// Mapping from the name of an externally registered creation method to its
239   /// index in the bytecode registry.
240   llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
241 
242   /// Mapping from rewriter function name to the bytecode address of the
243   /// rewriter function in byte.
244   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
245 
246   /// Mapping from a uniqued storage object to its memory index within
247   /// `uniquedData`.
248   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
249 
250   /// The current MLIR context.
251   MLIRContext *ctx;
252 
253   /// Data of the ByteCode class to be populated.
254   std::vector<const void *> &uniquedData;
255   SmallVectorImpl<ByteCodeField> &matcherByteCode;
256   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
257   SmallVectorImpl<PDLByteCodePattern> &patterns;
258   ByteCodeField &maxValueMemoryIndex;
259 };
260 
261 /// This class provides utilities for writing a bytecode stream.
262 struct ByteCodeWriter {
ByteCodeWriter__anon699baea10211::ByteCodeWriter263   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
264       : bytecode(bytecode), generator(generator) {}
265 
266   /// Append a field to the bytecode.
append__anon699baea10211::ByteCodeWriter267   void append(ByteCodeField field) { bytecode.push_back(field); }
append__anon699baea10211::ByteCodeWriter268   void append(OpCode opCode) { bytecode.push_back(opCode); }
269 
270   /// Append an address to the bytecode.
append__anon699baea10211::ByteCodeWriter271   void append(ByteCodeAddr field) {
272     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
273                   "unexpected ByteCode address size");
274 
275     ByteCodeField fieldParts[2];
276     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
277     bytecode.append({fieldParts[0], fieldParts[1]});
278   }
279 
280   /// Append a successor range to the bytecode, the exact address will need to
281   /// be resolved later.
append__anon699baea10211::ByteCodeWriter282   void append(SuccessorRange successors) {
283     // Add back references to the any successors so that the address can be
284     // resolved later.
285     for (Block *successor : successors) {
286       unresolvedSuccessorRefs[successor].push_back(bytecode.size());
287       append(ByteCodeAddr(0));
288     }
289   }
290 
291   /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anon699baea10211::ByteCodeWriter292   void appendPDLValueList(OperandRange values) {
293     bytecode.push_back(values.size());
294     for (Value value : values) {
295       // Append the type of the value in addition to the value itself.
296       PDLValueKind kind =
297           TypeSwitch<Type, PDLValueKind>(value.getType())
298               .Case<pdl::AttributeType>(
299                   [](Type) { return PDLValueKind::Attribute; })
300               .Case<pdl::OperationType>(
301                   [](Type) { return PDLValueKind::Operation; })
302               .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
303               .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
304       bytecode.push_back(static_cast<ByteCodeField>(kind));
305       append(value);
306     }
307   }
308 
309   /// Check if the given class `T` has an iterator type.
310   template <typename T, typename... Args>
311   using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
312 
313   /// Append a value that will be stored in a memory slot and not inline within
314   /// the bytecode.
315   template <typename T>
316   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
317                    std::is_pointer<T>::value>
append__anon699baea10211::ByteCodeWriter318   append(T value) {
319     bytecode.push_back(generator.getMemIndex(value));
320   }
321 
322   /// Append a range of values.
323   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
324   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anon699baea10211::ByteCodeWriter325   append(T range) {
326     bytecode.push_back(llvm::size(range));
327     for (auto it : range)
328       append(it);
329   }
330 
331   /// Append a variadic number of fields to the bytecode.
332   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anon699baea10211::ByteCodeWriter333   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
334     append(field);
335     append(field2, fields...);
336   }
337 
338   /// Successor references in the bytecode that have yet to be resolved.
339   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
340 
341   /// The underlying bytecode buffer.
342   SmallVectorImpl<ByteCodeField> &bytecode;
343 
344   /// The main generator producing PDL.
345   Generator &generator;
346 };
347 } // end anonymous namespace
348 
generate(ModuleOp module)349 void Generator::generate(ModuleOp module) {
350   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
351       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
352   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
353       pdl_interp::PDLInterpDialect::getRewriterModuleName());
354   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
355 
356   // Allocate memory indices for the results of operations within the matcher
357   // and rewriters.
358   allocateMemoryIndices(matcherFunc, rewriterModule);
359 
360   // Generate code for the rewriter functions.
361   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
362   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
363     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
364     for (Operation &op : rewriterFunc.getOps())
365       generate(&op, rewriterByteCodeWriter);
366   }
367   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
368          "unexpected branches in rewriter function");
369 
370   // Generate code for the matcher function.
371   DenseMap<Block *, ByteCodeAddr> blockToAddr;
372   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
373   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
374   for (Block *block : rpot) {
375     // Keep track of where this block begins within the matcher function.
376     blockToAddr.try_emplace(block, matcherByteCode.size());
377     for (Operation &op : *block)
378       generate(&op, matcherByteCodeWriter);
379   }
380 
381   // Resolve successor references in the matcher.
382   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
383     ByteCodeAddr addr = blockToAddr[it.first];
384     for (unsigned offsetToFix : it.second)
385       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
386   }
387 }
388 
allocateMemoryIndices(FuncOp matcherFunc,ModuleOp rewriterModule)389 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
390                                       ModuleOp rewriterModule) {
391   // Rewriters use simplistic allocation scheme that simply assigns an index to
392   // each result.
393   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
394     ByteCodeField index = 0;
395     for (BlockArgument arg : rewriterFunc.getArguments())
396       valueToMemIndex.try_emplace(arg, index++);
397     rewriterFunc.getBody().walk([&](Operation *op) {
398       for (Value result : op->getResults())
399         valueToMemIndex.try_emplace(result, index++);
400     });
401     if (index > maxValueMemoryIndex)
402       maxValueMemoryIndex = index;
403   }
404 
405   // The matcher function uses a more sophisticated numbering that tries to
406   // minimize the number of memory indices assigned. This is done by determining
407   // a live range of the values within the matcher, then the allocation is just
408   // finding the minimal number of overlapping live ranges. This is essentially
409   // a simplified form of register allocation where we don't necessarily have a
410   // limited number of registers, but we still want to minimize the number used.
411   DenseMap<Operation *, ByteCodeField> opToIndex;
412   matcherFunc.getBody().walk([&](Operation *op) {
413     opToIndex.insert(std::make_pair(op, opToIndex.size()));
414   });
415 
416   // Liveness info for each of the defs within the matcher.
417   using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
418   LivenessSet::Allocator allocator;
419   DenseMap<Value, LivenessSet> valueDefRanges;
420 
421   // Assign the root operation being matched to slot 0.
422   BlockArgument rootOpArg = matcherFunc.getArgument(0);
423   valueToMemIndex[rootOpArg] = 0;
424 
425   // Walk each of the blocks, computing the def interval that the value is used.
426   Liveness matcherLiveness(matcherFunc);
427   for (Block &block : matcherFunc.getBody()) {
428     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
429     assert(info && "expected liveness info for block");
430     auto processValue = [&](Value value, Operation *firstUseOrDef) {
431       // We don't need to process the root op argument, this value is always
432       // assigned to the first memory slot.
433       if (value == rootOpArg)
434         return;
435 
436       // Set indices for the range of this block that the value is used.
437       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
438       defRangeIt->second.insert(
439           opToIndex[firstUseOrDef],
440           opToIndex[info->getEndOperation(value, firstUseOrDef)],
441           /*dummyValue*/ 0);
442     };
443 
444     // Process the live-ins of this block.
445     for (Value liveIn : info->in())
446       processValue(liveIn, &block.front());
447 
448     // Process any new defs within this block.
449     for (Operation &op : block)
450       for (Value result : op.getResults())
451         processValue(result, &op);
452   }
453 
454   // Greedily allocate memory slots using the computed def live ranges.
455   std::vector<LivenessSet> allocatedIndices;
456   for (auto &defIt : valueDefRanges) {
457     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
458     LivenessSet &defSet = defIt.second;
459 
460     // Try to allocate to an existing index.
461     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
462       LivenessSet &existingIndex = existingIndexIt.value();
463       llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
464           defIt.second, existingIndex);
465       if (overlaps.valid())
466         continue;
467       // Union the range of the def within the existing index.
468       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
469         existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
470       memIndex = existingIndexIt.index() + 1;
471     }
472 
473     // If no existing index could be used, add a new one.
474     if (memIndex == 0) {
475       allocatedIndices.emplace_back(allocator);
476       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
477         allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
478       memIndex = allocatedIndices.size();
479     }
480   }
481 
482   // Update the max number of indices.
483   ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
484   if (numMatcherIndices > maxValueMemoryIndex)
485     maxValueMemoryIndex = numMatcherIndices;
486 }
487 
generate(Operation * op,ByteCodeWriter & writer)488 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
489   TypeSwitch<Operation *>(op)
490       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
491             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
492             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
493             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
494             pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
495             pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
496             pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
497             pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
498             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
499             pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
500             pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
501             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
502             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
503             pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
504             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
505           [&](auto interpOp) { this->generate(interpOp, writer); })
506       .Default([](Operation *) {
507         llvm_unreachable("unknown `pdl_interp` operation");
508       });
509 }
510 
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)511 void Generator::generate(pdl_interp::ApplyConstraintOp op,
512                          ByteCodeWriter &writer) {
513   assert(constraintToMemIndex.count(op.name()) &&
514          "expected index for constraint function");
515   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
516                 op.constParamsAttr());
517   writer.appendPDLValueList(op.args());
518   writer.append(op.getSuccessors());
519 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)520 void Generator::generate(pdl_interp::ApplyRewriteOp op,
521                          ByteCodeWriter &writer) {
522   assert(externalRewriterToMemIndex.count(op.name()) &&
523          "expected index for rewrite function");
524   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
525                 op.constParamsAttr(), op.root());
526   writer.appendPDLValueList(op.args());
527 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)528 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
529   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
530 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)531 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
532   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
533 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)534 void Generator::generate(pdl_interp::CheckAttributeOp op,
535                          ByteCodeWriter &writer) {
536   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
537                 op.getSuccessors());
538 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)539 void Generator::generate(pdl_interp::CheckOperandCountOp op,
540                          ByteCodeWriter &writer) {
541   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
542                 op.getSuccessors());
543 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)544 void Generator::generate(pdl_interp::CheckOperationNameOp op,
545                          ByteCodeWriter &writer) {
546   writer.append(OpCode::CheckOperationName, op.operation(),
547                 OperationName(op.name(), ctx), op.getSuccessors());
548 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)549 void Generator::generate(pdl_interp::CheckResultCountOp op,
550                          ByteCodeWriter &writer) {
551   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
552                 op.getSuccessors());
553 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)554 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
555   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
556 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)557 void Generator::generate(pdl_interp::CreateAttributeOp op,
558                          ByteCodeWriter &writer) {
559   // Simply repoint the memory index of the result to the constant.
560   getMemIndex(op.attribute()) = getMemIndex(op.value());
561 }
generate(pdl_interp::CreateNativeOp op,ByteCodeWriter & writer)562 void Generator::generate(pdl_interp::CreateNativeOp op,
563                          ByteCodeWriter &writer) {
564   assert(nativeCreateToMemIndex.count(op.name()) &&
565          "expected index for creation function");
566   writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
567                 op.result(), op.constParamsAttr());
568   writer.appendPDLValueList(op.args());
569 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)570 void Generator::generate(pdl_interp::CreateOperationOp op,
571                          ByteCodeWriter &writer) {
572   writer.append(OpCode::CreateOperation, op.operation(),
573                 OperationName(op.name(), ctx), op.operands());
574 
575   // Add the attributes.
576   OperandRange attributes = op.attributes();
577   writer.append(static_cast<ByteCodeField>(attributes.size()));
578   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
579     writer.append(
580         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
581         std::get<1>(it));
582   }
583   writer.append(op.types());
584 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)585 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
586   // Simply repoint the memory index of the result to the constant.
587   getMemIndex(op.result()) = getMemIndex(op.value());
588 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)589 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
590   writer.append(OpCode::EraseOp, op.operation());
591 }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)592 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
593   writer.append(OpCode::Finalize);
594 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)595 void Generator::generate(pdl_interp::GetAttributeOp op,
596                          ByteCodeWriter &writer) {
597   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
598                 Identifier::get(op.name(), ctx));
599 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)600 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
601                          ByteCodeWriter &writer) {
602   writer.append(OpCode::GetAttributeType, op.result(), op.value());
603 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)604 void Generator::generate(pdl_interp::GetDefiningOpOp op,
605                          ByteCodeWriter &writer) {
606   writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
607 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)608 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
609   uint32_t index = op.index();
610   if (index < 4)
611     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
612   else
613     writer.append(OpCode::GetOperandN, index);
614   writer.append(op.operation(), op.value());
615 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)616 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
617   uint32_t index = op.index();
618   if (index < 4)
619     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
620   else
621     writer.append(OpCode::GetResultN, index);
622   writer.append(op.operation(), op.value());
623 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)624 void Generator::generate(pdl_interp::GetValueTypeOp op,
625                          ByteCodeWriter &writer) {
626   writer.append(OpCode::GetValueType, op.result(), op.value());
627 }
generate(pdl_interp::InferredTypeOp op,ByteCodeWriter & writer)628 void Generator::generate(pdl_interp::InferredTypeOp op,
629                          ByteCodeWriter &writer) {
630   // InferType maps to a null type as a marker for inferring a result type.
631   getMemIndex(op.type()) = getMemIndex(Type());
632 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)633 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
634   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
635 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)636 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
637   ByteCodeField patternIndex = patterns.size();
638   patterns.emplace_back(PDLByteCodePattern::create(
639       op, rewriterToAddr[op.rewriter().getLeafReference()]));
640   writer.append(OpCode::RecordMatch, patternIndex,
641                 SuccessorRange(op.getOperation()), op.matchedOps(),
642                 op.inputs());
643 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)644 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
645   writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
646 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)647 void Generator::generate(pdl_interp::SwitchAttributeOp op,
648                          ByteCodeWriter &writer) {
649   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
650                 op.getSuccessors());
651 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)652 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
653                          ByteCodeWriter &writer) {
654   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
655                 op.getSuccessors());
656 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)657 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
658                          ByteCodeWriter &writer) {
659   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
660     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
661   });
662   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
663                 op.getSuccessors());
664 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)665 void Generator::generate(pdl_interp::SwitchResultCountOp op,
666                          ByteCodeWriter &writer) {
667   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
668                 op.getSuccessors());
669 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)670 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
671   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
672                 op.getSuccessors());
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // PDLByteCode
677 //===----------------------------------------------------------------------===//
678 
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLCreateFunction> createFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)679 PDLByteCode::PDLByteCode(ModuleOp module,
680                          llvm::StringMap<PDLConstraintFunction> constraintFns,
681                          llvm::StringMap<PDLCreateFunction> createFns,
682                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
683   Generator generator(module.getContext(), uniquedData, matcherByteCode,
684                       rewriterByteCode, patterns, maxValueMemoryIndex,
685                       constraintFns, createFns, rewriteFns);
686   generator.generate(module);
687 
688   // Initialize the external functions.
689   for (auto &it : constraintFns)
690     constraintFunctions.push_back(std::move(it.second));
691   for (auto &it : createFns)
692     createFunctions.push_back(std::move(it.second));
693   for (auto &it : rewriteFns)
694     rewriteFunctions.push_back(std::move(it.second));
695 }
696 
697 /// Initialize the given state such that it can be used to execute the current
698 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const699 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
700   state.memory.resize(maxValueMemoryIndex, nullptr);
701   state.currentPatternBenefits.reserve(patterns.size());
702   for (const PDLByteCodePattern &pattern : patterns)
703     state.currentPatternBenefits.push_back(pattern.getBenefit());
704 }
705 
706 //===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
708 
709 namespace {
710 /// This class provides support for executing a bytecode stream.
711 class ByteCodeExecutor {
712 public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLCreateFunction> createFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)713   ByteCodeExecutor(const ByteCodeField *curCodeIt,
714                    MutableArrayRef<const void *> memory,
715                    ArrayRef<const void *> uniquedMemory,
716                    ArrayRef<ByteCodeField> code,
717                    ArrayRef<PatternBenefit> currentPatternBenefits,
718                    ArrayRef<PDLByteCodePattern> patterns,
719                    ArrayRef<PDLConstraintFunction> constraintFunctions,
720                    ArrayRef<PDLCreateFunction> createFunctions,
721                    ArrayRef<PDLRewriteFunction> rewriteFunctions)
722       : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
723         code(code), currentPatternBenefits(currentPatternBenefits),
724         patterns(patterns), constraintFunctions(constraintFunctions),
725         createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
726 
727   /// Start executing the code at the current bytecode index. `matches` is an
728   /// optional field provided when this function is executed in a matching
729   /// context.
730   void execute(PatternRewriter &rewriter,
731                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
732                Optional<Location> mainRewriteLoc = {});
733 
734 private:
735   /// Read a value from the bytecode buffer, optionally skipping a certain
736   /// number of prefix values. These methods always update the buffer to point
737   /// to the next field after the read data.
738   template <typename T = ByteCodeField>
read(size_t skipN=0)739   T read(size_t skipN = 0) {
740     curCodeIt += skipN;
741     return readImpl<T>();
742   }
read(size_t skipN=0)743   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
744 
745   /// Read a list of values from the bytecode buffer.
746   template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)747   void readList(SmallVectorImpl<T> &list) {
748     list.clear();
749     for (unsigned i = 0, e = read(); i != e; ++i)
750       list.push_back(read<ValueT>());
751   }
752 
753   /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)754   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
755   /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)756   void selectJump(size_t destIndex) {
757     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
758   }
759 
760   /// Handle a switch operation with the provided value and cases.
761   template <typename T, typename RangeT>
handleSwitch(const T & value,RangeT && cases)762   void handleSwitch(const T &value, RangeT &&cases) {
763     LLVM_DEBUG({
764       llvm::dbgs() << "  * Value: " << value << "\n"
765                    << "  * Cases: ";
766       llvm::interleaveComma(cases, llvm::dbgs());
767       llvm::dbgs() << "\n\n";
768     });
769 
770     // Check to see if the attribute value is within the case list. Jump to
771     // the correct successor index based on the result.
772     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
773       if (*it == value)
774         return selectJump(size_t((it - cases.begin()) + 1));
775     selectJump(size_t(0));
776   }
777 
778   /// Internal implementation of reading various data types from the bytecode
779   /// stream.
780   template <typename T>
readFromMemory()781   const void *readFromMemory() {
782     size_t index = *curCodeIt++;
783 
784     // If this type is an SSA value, it can only be stored in non-const memory.
785     if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
786       return memory[index];
787 
788     // Otherwise, if this index is not inbounds it is uniqued.
789     return uniquedMemory[index - memory.size()];
790   }
791   template <typename T>
readImpl()792   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
793     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
794   }
795   template <typename T>
796   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
797                    T>
readImpl()798   readImpl() {
799     return T(T::getFromOpaquePointer(readFromMemory<T>()));
800   }
801   template <typename T>
readImpl()802   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
803     switch (static_cast<PDLValueKind>(read())) {
804     case PDLValueKind::Attribute:
805       return read<Attribute>();
806     case PDLValueKind::Operation:
807       return read<Operation *>();
808     case PDLValueKind::Type:
809       return read<Type>();
810     case PDLValueKind::Value:
811       return read<Value>();
812     }
813     llvm_unreachable("unhandled PDLValueKind");
814   }
815   template <typename T>
readImpl()816   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
817     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
818                   "unexpected ByteCode address size");
819     ByteCodeAddr result;
820     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
821     curCodeIt += 2;
822     return result;
823   }
824   template <typename T>
readImpl()825   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
826     return *curCodeIt++;
827   }
828 
829   /// The underlying bytecode buffer.
830   const ByteCodeField *curCodeIt;
831 
832   /// The current execution memory.
833   MutableArrayRef<const void *> memory;
834 
835   /// References to ByteCode data necessary for execution.
836   ArrayRef<const void *> uniquedMemory;
837   ArrayRef<ByteCodeField> code;
838   ArrayRef<PatternBenefit> currentPatternBenefits;
839   ArrayRef<PDLByteCodePattern> patterns;
840   ArrayRef<PDLConstraintFunction> constraintFunctions;
841   ArrayRef<PDLCreateFunction> createFunctions;
842   ArrayRef<PDLRewriteFunction> rewriteFunctions;
843 };
844 } // end anonymous namespace
845 
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)846 void ByteCodeExecutor::execute(
847     PatternRewriter &rewriter,
848     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
849     Optional<Location> mainRewriteLoc) {
850   while (true) {
851     OpCode opCode = static_cast<OpCode>(read());
852     switch (opCode) {
853     case ApplyConstraint: {
854       LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
855       const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
856       ArrayAttr constParams = read<ArrayAttr>();
857       SmallVector<PDLValue, 16> args;
858       readList<PDLValue>(args);
859       LLVM_DEBUG({
860         llvm::dbgs() << "  * Arguments: ";
861         llvm::interleaveComma(args, llvm::dbgs());
862         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
863       });
864 
865       // Invoke the constraint and jump to the proper destination.
866       selectJump(succeeded(constraintFn(args, constParams, rewriter)));
867       break;
868     }
869     case ApplyRewrite: {
870       LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
871       const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
872       ArrayAttr constParams = read<ArrayAttr>();
873       Operation *root = read<Operation *>();
874       SmallVector<PDLValue, 16> args;
875       readList<PDLValue>(args);
876 
877       LLVM_DEBUG({
878         llvm::dbgs() << "  * Root: " << *root << "\n"
879                      << "  * Arguments: ";
880         llvm::interleaveComma(args, llvm::dbgs());
881         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
882       });
883       rewriteFn(root, args, constParams, rewriter);
884       break;
885     }
886     case AreEqual: {
887       LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
888       const void *lhs = read<const void *>();
889       const void *rhs = read<const void *>();
890 
891       LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
892       selectJump(lhs == rhs);
893       break;
894     }
895     case Branch: {
896       LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
897       curCodeIt = &code[read<ByteCodeAddr>()];
898       break;
899     }
900     case CheckOperandCount: {
901       LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
902       Operation *op = read<Operation *>();
903       uint32_t expectedCount = read<uint32_t>();
904 
905       LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
906                               << "  * Expected: " << expectedCount << "\n\n");
907       selectJump(op->getNumOperands() == expectedCount);
908       break;
909     }
910     case CheckOperationName: {
911       LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
912       Operation *op = read<Operation *>();
913       OperationName expectedName = read<OperationName>();
914 
915       LLVM_DEBUG(llvm::dbgs()
916                  << "  * Found: \"" << op->getName() << "\"\n"
917                  << "  * Expected: \"" << expectedName << "\"\n\n");
918       selectJump(op->getName() == expectedName);
919       break;
920     }
921     case CheckResultCount: {
922       LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
923       Operation *op = read<Operation *>();
924       uint32_t expectedCount = read<uint32_t>();
925 
926       LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
927                               << "  * Expected: " << expectedCount << "\n\n");
928       selectJump(op->getNumResults() == expectedCount);
929       break;
930     }
931     case CreateNative: {
932       LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
933       const PDLCreateFunction &createFn = createFunctions[read()];
934       ByteCodeField resultIndex = read();
935       ArrayAttr constParams = read<ArrayAttr>();
936       SmallVector<PDLValue, 16> args;
937       readList<PDLValue>(args);
938 
939       LLVM_DEBUG({
940         llvm::dbgs() << "  * Arguments: ";
941         llvm::interleaveComma(args, llvm::dbgs());
942         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
943       });
944 
945       PDLValue result = createFn(args, constParams, rewriter);
946       memory[resultIndex] = result.getAsOpaquePointer();
947 
948       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n\n");
949       break;
950     }
951     case CreateOperation: {
952       LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
953       assert(mainRewriteLoc && "expected rewrite loc to be provided when "
954                                "executing the rewriter bytecode");
955 
956       unsigned memIndex = read();
957       OperationState state(*mainRewriteLoc, read<OperationName>());
958       readList<Value>(state.operands);
959       for (unsigned i = 0, e = read(); i != e; ++i) {
960         Identifier name = read<Identifier>();
961         if (Attribute attr = read<Attribute>())
962           state.addAttribute(name, attr);
963       }
964 
965       bool hasInferredTypes = false;
966       for (unsigned i = 0, e = read(); i != e; ++i) {
967         Type resultType = read<Type>();
968         hasInferredTypes |= !resultType;
969         state.types.push_back(resultType);
970       }
971 
972       // Handle the case where the operation has inferred types.
973       if (hasInferredTypes) {
974         InferTypeOpInterface::Concept *concept =
975             state.name.getAbstractOperation()
976                 ->getInterface<InferTypeOpInterface>();
977 
978         // TODO: Handle failure.
979         SmallVector<Type, 2> inferredTypes;
980         if (failed(concept->inferReturnTypes(
981                 state.getContext(), state.location, state.operands,
982                 state.attributes.getDictionary(state.getContext()),
983                 state.regions, inferredTypes)))
984           return;
985 
986         for (unsigned i = 0, e = state.types.size(); i != e; ++i)
987           if (!state.types[i])
988             state.types[i] = inferredTypes[i];
989       }
990       Operation *resultOp = rewriter.createOperation(state);
991       memory[memIndex] = resultOp;
992 
993       LLVM_DEBUG({
994         llvm::dbgs() << "  * Attributes: "
995                      << state.attributes.getDictionary(state.getContext())
996                      << "\n  * Operands: ";
997         llvm::interleaveComma(state.operands, llvm::dbgs());
998         llvm::dbgs() << "\n  * Result Types: ";
999         llvm::interleaveComma(state.types, llvm::dbgs());
1000         llvm::dbgs() << "\n  * Result: " << *resultOp << "\n\n";
1001       });
1002       break;
1003     }
1004     case EraseOp: {
1005       LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1006       Operation *op = read<Operation *>();
1007 
1008       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n\n");
1009       rewriter.eraseOp(op);
1010       break;
1011     }
1012     case Finalize: {
1013       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1014       return;
1015     }
1016     case GetAttribute: {
1017       LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1018       unsigned memIndex = read();
1019       Operation *op = read<Operation *>();
1020       Identifier attrName = read<Identifier>();
1021       Attribute attr = op->getAttr(attrName);
1022 
1023       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1024                               << "  * Attribute: " << attrName << "\n"
1025                               << "  * Result: " << attr << "\n\n");
1026       memory[memIndex] = attr.getAsOpaquePointer();
1027       break;
1028     }
1029     case GetAttributeType: {
1030       LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1031       unsigned memIndex = read();
1032       Attribute attr = read<Attribute>();
1033 
1034       LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1035                               << "  * Result: " << attr.getType() << "\n\n");
1036       memory[memIndex] = attr.getType().getAsOpaquePointer();
1037       break;
1038     }
1039     case GetDefiningOp: {
1040       LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1041       unsigned memIndex = read();
1042       Value value = read<Value>();
1043       Operation *op = value ? value.getDefiningOp() : nullptr;
1044 
1045       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1046                               << "  * Result: " << *op << "\n\n");
1047       memory[memIndex] = op;
1048       break;
1049     }
1050     case GetOperand0:
1051     case GetOperand1:
1052     case GetOperand2:
1053     case GetOperand3:
1054     case GetOperandN: {
1055       LLVM_DEBUG({
1056         llvm::dbgs() << "Executing GetOperand"
1057                      << (opCode == GetOperandN ? Twine("N")
1058                                                : Twine(opCode - GetOperand0))
1059                      << ":\n";
1060       });
1061       unsigned index =
1062           opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
1063       Operation *op = read<Operation *>();
1064       unsigned memIndex = read();
1065       Value operand =
1066           index < op->getNumOperands() ? op->getOperand(index) : Value();
1067 
1068       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1069                               << "  * Index: " << index << "\n"
1070                               << "  * Result: " << operand << "\n\n");
1071       memory[memIndex] = operand.getAsOpaquePointer();
1072       break;
1073     }
1074     case GetResult0:
1075     case GetResult1:
1076     case GetResult2:
1077     case GetResult3:
1078     case GetResultN: {
1079       LLVM_DEBUG({
1080         llvm::dbgs() << "Executing GetResult"
1081                      << (opCode == GetResultN ? Twine("N")
1082                                               : Twine(opCode - GetResult0))
1083                      << ":\n";
1084       });
1085       unsigned index =
1086           opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
1087       Operation *op = read<Operation *>();
1088       unsigned memIndex = read();
1089       OpResult result =
1090           index < op->getNumResults() ? op->getResult(index) : OpResult();
1091 
1092       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1093                               << "  * Index: " << index << "\n"
1094                               << "  * Result: " << result << "\n\n");
1095       memory[memIndex] = result.getAsOpaquePointer();
1096       break;
1097     }
1098     case GetValueType: {
1099       LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1100       unsigned memIndex = read();
1101       Value value = read<Value>();
1102 
1103       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1104                               << "  * Result: " << value.getType() << "\n\n");
1105       memory[memIndex] = value.getType().getAsOpaquePointer();
1106       break;
1107     }
1108     case IsNotNull: {
1109       LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1110       const void *value = read<const void *>();
1111 
1112       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n\n");
1113       selectJump(value != nullptr);
1114       break;
1115     }
1116     case RecordMatch: {
1117       LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1118       assert(matches &&
1119              "expected matches to be provided when executing the matcher");
1120       unsigned patternIndex = read();
1121       PatternBenefit benefit = currentPatternBenefits[patternIndex];
1122       const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1123 
1124       // If the benefit of the pattern is impossible, skip the processing of the
1125       // rest of the pattern.
1126       if (benefit.isImpossibleToMatch()) {
1127         LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n\n");
1128         curCodeIt = dest;
1129         break;
1130       }
1131 
1132       // Create a fused location containing the locations of each of the
1133       // operations used in the match. This will be used as the location for
1134       // created operations during the rewrite that don't already have an
1135       // explicit location set.
1136       unsigned numMatchLocs = read();
1137       SmallVector<Location, 4> matchLocs;
1138       matchLocs.reserve(numMatchLocs);
1139       for (unsigned i = 0; i != numMatchLocs; ++i)
1140         matchLocs.push_back(read<Operation *>()->getLoc());
1141       Location matchLoc = rewriter.getFusedLoc(matchLocs);
1142 
1143       LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1144                               << "  * Location: " << matchLoc << "\n\n");
1145       matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
1146       readList<const void *>(matches->back().values);
1147       curCodeIt = dest;
1148       break;
1149     }
1150     case ReplaceOp: {
1151       LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1152       Operation *op = read<Operation *>();
1153       SmallVector<Value, 16> args;
1154       readList<Value>(args);
1155 
1156       LLVM_DEBUG({
1157         llvm::dbgs() << "  * Operation: " << *op << "\n"
1158                      << "  * Values: ";
1159         llvm::interleaveComma(args, llvm::dbgs());
1160         llvm::dbgs() << "\n\n";
1161       });
1162       rewriter.replaceOp(op, args);
1163       break;
1164     }
1165     case SwitchAttribute: {
1166       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1167       Attribute value = read<Attribute>();
1168       ArrayAttr cases = read<ArrayAttr>();
1169       handleSwitch(value, cases);
1170       break;
1171     }
1172     case SwitchOperandCount: {
1173       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1174       Operation *op = read<Operation *>();
1175       auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1176 
1177       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1178       handleSwitch(op->getNumOperands(), cases);
1179       break;
1180     }
1181     case SwitchOperationName: {
1182       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1183       OperationName value = read<Operation *>()->getName();
1184       size_t caseCount = read();
1185 
1186       // The operation names are stored in-line, so to print them out for
1187       // debugging purposes we need to read the array before executing the
1188       // switch so that we can display all of the possible values.
1189       LLVM_DEBUG({
1190         const ByteCodeField *prevCodeIt = curCodeIt;
1191         llvm::dbgs() << "  * Value: " << value << "\n"
1192                      << "  * Cases: ";
1193         llvm::interleaveComma(
1194             llvm::map_range(llvm::seq<size_t>(0, caseCount),
1195                             [&](size_t i) { return read<OperationName>(); }),
1196             llvm::dbgs());
1197         llvm::dbgs() << "\n\n";
1198         curCodeIt = prevCodeIt;
1199       });
1200 
1201       // Try to find the switch value within any of the cases.
1202       size_t jumpDest = 0;
1203       for (size_t i = 0; i != caseCount; ++i) {
1204         if (read<OperationName>() == value) {
1205           curCodeIt += (caseCount - i - 1);
1206           jumpDest = i + 1;
1207           break;
1208         }
1209       }
1210       selectJump(jumpDest);
1211       break;
1212     }
1213     case SwitchResultCount: {
1214       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1215       Operation *op = read<Operation *>();
1216       auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1217 
1218       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1219       handleSwitch(op->getNumResults(), cases);
1220       break;
1221     }
1222     case SwitchType: {
1223       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1224       Type value = read<Type>();
1225       auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1226       handleSwitch(value, cases);
1227       break;
1228     }
1229     }
1230   }
1231 }
1232 
1233 /// Run the pattern matcher on the given root operation, collecting the matched
1234 /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const1235 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1236                         SmallVectorImpl<MatchResult> &matches,
1237                         PDLByteCodeMutableState &state) const {
1238   // The first memory slot is always the root operation.
1239   state.memory[0] = op;
1240 
1241   // The matcher function always starts at code address 0.
1242   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1243                             matcherByteCode, state.currentPatternBenefits,
1244                             patterns, constraintFunctions, createFunctions,
1245                             rewriteFunctions);
1246   executor.execute(rewriter, &matches);
1247 
1248   // Order the found matches by benefit.
1249   std::stable_sort(matches.begin(), matches.end(),
1250                    [](const MatchResult &lhs, const MatchResult &rhs) {
1251                      return lhs.benefit > rhs.benefit;
1252                    });
1253 }
1254 
1255 /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const1256 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1257                           PDLByteCodeMutableState &state) const {
1258   // The arguments of the rewrite function are stored at the start of the
1259   // memory buffer.
1260   llvm::copy(match.values, state.memory.begin());
1261 
1262   ByteCodeExecutor executor(
1263       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1264       uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
1265       constraintFunctions, createFunctions, rewriteFunctions);
1266   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1267 }
1268