1 //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file declares a byte-code and interpreter for pattern rewrites in MLIR. 10 // The byte-code is constructed from the PDL Interpreter dialect. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_REWRITE_BYTECODE_H_ 15 #define MLIR_REWRITE_BYTECODE_H_ 16 17 #include "mlir/IR/PatternMatch.h" 18 19 namespace mlir { 20 namespace pdl_interp { 21 class RecordMatchOp; 22 } // end namespace pdl_interp 23 24 namespace detail { 25 class PDLByteCode; 26 27 /// Use generic bytecode types. ByteCodeField refers to the actual bytecode 28 /// entries. ByteCodeAddr refers to size of indices into the bytecode. 29 using ByteCodeField = uint16_t; 30 using ByteCodeAddr = uint32_t; 31 32 //===----------------------------------------------------------------------===// 33 // PDLByteCodePattern 34 //===----------------------------------------------------------------------===// 35 36 /// All of the data pertaining to a specific pattern within the bytecode. 37 class PDLByteCodePattern : public Pattern { 38 public: 39 static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, 40 ByteCodeAddr rewriterAddr); 41 42 /// Return the bytecode address of the rewriter for this pattern. getRewriterAddr()43 ByteCodeAddr getRewriterAddr() const { return rewriterAddr; } 44 45 private: 46 template <typename... Args> PDLByteCodePattern(ByteCodeAddr rewriterAddr,Args &&...patternArgs)47 PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs) 48 : Pattern(std::forward<Args>(patternArgs)...), 49 rewriterAddr(rewriterAddr) {} 50 51 /// The address of the rewriter for this pattern. 52 ByteCodeAddr rewriterAddr; 53 }; 54 55 //===----------------------------------------------------------------------===// 56 // PDLByteCodeMutableState 57 //===----------------------------------------------------------------------===// 58 59 /// This class contains the mutable state of a bytecode instance. This allows 60 /// for a bytecode instance to be cached and reused across various different 61 /// threads/drivers. 62 class PDLByteCodeMutableState { 63 public: 64 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds 65 /// to the position of the pattern within the range returned by 66 /// `PDLByteCode::getPatterns`. 67 void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit); 68 69 /// Cleanup any allocated state after a match/rewrite has been completed. This 70 /// method should be called irregardless of whether the match+rewrite was a 71 /// success or not. 72 void cleanupAfterMatchAndRewrite(); 73 74 private: 75 /// Allow access to data fields. 76 friend class PDLByteCode; 77 78 /// The mutable block of memory used during the matching and rewriting phases 79 /// of the bytecode. 80 std::vector<const void *> memory; 81 82 /// A mutable block of memory used during the matching and rewriting phase of 83 /// the bytecode to store ranges of types. 84 std::vector<TypeRange> typeRangeMemory; 85 /// A set of type ranges that have been allocated by the byte code interpreter 86 /// to provide a guaranteed lifetime. 87 std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory; 88 89 /// A mutable block of memory used during the matching and rewriting phase of 90 /// the bytecode to store ranges of values. 91 std::vector<ValueRange> valueRangeMemory; 92 /// A set of value ranges that have been allocated by the byte code 93 /// interpreter to provide a guaranteed lifetime. 94 std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory; 95 96 /// The up-to-date benefits of the patterns held by the bytecode. The order 97 /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`. 98 std::vector<PatternBenefit> currentPatternBenefits; 99 }; 100 101 //===----------------------------------------------------------------------===// 102 // PDLByteCode 103 //===----------------------------------------------------------------------===// 104 105 /// The bytecode class is also the interpreter. Contains the bytecode itself, 106 /// the static info, addresses of the rewriter functions, the interpreter 107 /// memory buffer, and the execution context. 108 class PDLByteCode { 109 public: 110 /// Each successful match returns a MatchResult, which contains information 111 /// necessary to execute the rewriter and indicates the originating pattern. 112 struct MatchResult { MatchResultMatchResult113 MatchResult(Location loc, const PDLByteCodePattern &pattern, 114 PatternBenefit benefit) 115 : location(loc), pattern(&pattern), benefit(benefit) {} 116 MatchResult(const MatchResult &) = delete; 117 MatchResult &operator=(const MatchResult &) = delete; 118 MatchResult(MatchResult &&other) = default; 119 MatchResult &operator=(MatchResult &&) = default; 120 121 /// The location of operations to be replaced. 122 Location location; 123 /// Memory values defined in the matcher that are passed to the rewriter. 124 SmallVector<const void *> values; 125 /// Memory used for the range input values. 126 SmallVector<TypeRange, 0> typeRangeValues; 127 SmallVector<ValueRange, 0> valueRangeValues; 128 129 /// The originating pattern that was matched. This is always non-null, but 130 /// represented with a pointer to allow for assignment. 131 const PDLByteCodePattern *pattern; 132 /// The current benefit of the pattern that was matched. 133 PatternBenefit benefit; 134 }; 135 136 /// Create a ByteCode instance from the given module containing operations in 137 /// the PDL interpreter dialect. 138 PDLByteCode(ModuleOp module, 139 llvm::StringMap<PDLConstraintFunction> constraintFns, 140 llvm::StringMap<PDLRewriteFunction> rewriteFns); 141 142 /// Return the patterns held by the bytecode. getPatterns()143 ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; } 144 145 /// Initialize the given state such that it can be used to execute the current 146 /// bytecode. 147 void initializeMutableState(PDLByteCodeMutableState &state) const; 148 149 /// Run the pattern matcher on the given root operation, collecting the 150 /// matched patterns in `matches`. 151 void match(Operation *op, PatternRewriter &rewriter, 152 SmallVectorImpl<MatchResult> &matches, 153 PDLByteCodeMutableState &state) const; 154 155 /// Run the rewriter of the given pattern that was previously matched in 156 /// `match`. 157 void rewrite(PatternRewriter &rewriter, const MatchResult &match, 158 PDLByteCodeMutableState &state) const; 159 160 private: 161 /// Execute the given byte code starting at the provided instruction `inst`. 162 /// `matches` is an optional field provided when this function is executed in 163 /// a matching context. 164 void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter, 165 PDLByteCodeMutableState &state, 166 SmallVectorImpl<MatchResult> *matches) const; 167 168 /// A vector containing pointers to uniqued data. The storage is intentionally 169 /// opaque such that we can store a wide range of data types. The types of 170 /// data stored here include: 171 /// * Attribute, Identifier, OperationName, Type 172 std::vector<const void *> uniquedData; 173 174 /// A vector containing the generated bytecode for the matcher. 175 SmallVector<ByteCodeField, 64> matcherByteCode; 176 177 /// A vector containing the generated bytecode for all of the rewriters. 178 SmallVector<ByteCodeField, 64> rewriterByteCode; 179 180 /// The set of patterns contained within the bytecode. 181 SmallVector<PDLByteCodePattern, 32> patterns; 182 183 /// A set of user defined functions invoked via PDL. 184 std::vector<PDLConstraintFunction> constraintFunctions; 185 std::vector<PDLRewriteFunction> rewriteFunctions; 186 187 /// The maximum memory index used by a value. 188 ByteCodeField maxValueMemoryIndex = 0; 189 190 /// The maximum number of different types of ranges. 191 ByteCodeField maxTypeRangeCount = 0; 192 ByteCodeField maxValueRangeCount = 0; 193 }; 194 195 } // end namespace detail 196 } // end namespace mlir 197 198 #endif // MLIR_REWRITE_BYTECODE_H_ 199