1 //===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares a byte-code and interpreter for pattern rewrites in MLIR.
10 // The byte-code is constructed from the PDL Interpreter dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_REWRITE_BYTECODE_H_
15 #define MLIR_REWRITE_BYTECODE_H_
16 
17 #include "mlir/IR/PatternMatch.h"
18 
19 namespace mlir {
namespace pdl_interp {
/// Forward declaration only: `RecordMatchOp` is named in the signature of
/// `PDLByteCodePattern::create`, which does not need the complete type here.
class RecordMatchOp;
} // namespace pdl_interp
23 
24 namespace detail {
class PDLByteCode;

/// Generic integer types used by the bytecode.
///
/// ByteCodeAddr: the type of an index (address) into the bytecode, e.g. the
/// address at which a pattern's rewriter begins.
using ByteCodeAddr = uint32_t;
/// ByteCodeField: the storage type of a single bytecode entry.
using ByteCodeField = uint16_t;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 /// All of the data pertaining to a specific pattern within the bytecode.
37 class PDLByteCodePattern : public Pattern {
38 public:
39   static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
40                                    ByteCodeAddr rewriterAddr);
41 
42   /// Return the bytecode address of the rewriter for this pattern.
getRewriterAddr()43   ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
44 
45 private:
46   template <typename... Args>
PDLByteCodePattern(ByteCodeAddr rewriterAddr,Args &&...patternArgs)47   PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
48       : Pattern(std::forward<Args>(patternArgs)...),
49         rewriterAddr(rewriterAddr) {}
50 
51   /// The address of the rewriter for this pattern.
52   ByteCodeAddr rewriterAddr;
53 };
54 
55 //===----------------------------------------------------------------------===//
56 // PDLByteCodeMutableState
57 //===----------------------------------------------------------------------===//
58 
59 /// This class contains the mutable state of a bytecode instance. This allows
60 /// for a bytecode instance to be cached and reused across various different
61 /// threads/drivers.
62 class PDLByteCodeMutableState {
63 public:
64   /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
65   /// to the position of the pattern within the range returned by
66   /// `PDLByteCode::getPatterns`.
67   void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
68 
69   /// Cleanup any allocated state after a match/rewrite has been completed. This
70   /// method should be called irregardless of whether the match+rewrite was a
71   /// success or not.
72   void cleanupAfterMatchAndRewrite();
73 
74 private:
75   /// Allow access to data fields.
76   friend class PDLByteCode;
77 
78   /// The mutable block of memory used during the matching and rewriting phases
79   /// of the bytecode.
80   std::vector<const void *> memory;
81 
82   /// A mutable block of memory used during the matching and rewriting phase of
83   /// the bytecode to store ranges of types.
84   std::vector<TypeRange> typeRangeMemory;
85   /// A set of type ranges that have been allocated by the byte code interpreter
86   /// to provide a guaranteed lifetime.
87   std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
88 
89   /// A mutable block of memory used during the matching and rewriting phase of
90   /// the bytecode to store ranges of values.
91   std::vector<ValueRange> valueRangeMemory;
92   /// A set of value ranges that have been allocated by the byte code
93   /// interpreter to provide a guaranteed lifetime.
94   std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
95 
96   /// The up-to-date benefits of the patterns held by the bytecode. The order
97   /// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
98   std::vector<PatternBenefit> currentPatternBenefits;
99 };
100 
101 //===----------------------------------------------------------------------===//
102 // PDLByteCode
103 //===----------------------------------------------------------------------===//
104 
105 /// The bytecode class is also the interpreter. Contains the bytecode itself,
106 /// the static info, addresses of the rewriter functions, the interpreter
107 /// memory buffer, and the execution context.
108 class PDLByteCode {
109 public:
110   /// Each successful match returns a MatchResult, which contains information
111   /// necessary to execute the rewriter and indicates the originating pattern.
112   struct MatchResult {
MatchResultMatchResult113     MatchResult(Location loc, const PDLByteCodePattern &pattern,
114                 PatternBenefit benefit)
115         : location(loc), pattern(&pattern), benefit(benefit) {}
116     MatchResult(const MatchResult &) = delete;
117     MatchResult &operator=(const MatchResult &) = delete;
118     MatchResult(MatchResult &&other) = default;
119     MatchResult &operator=(MatchResult &&) = default;
120 
121     /// The location of operations to be replaced.
122     Location location;
123     /// Memory values defined in the matcher that are passed to the rewriter.
124     SmallVector<const void *> values;
125     /// Memory used for the range input values.
126     SmallVector<TypeRange, 0> typeRangeValues;
127     SmallVector<ValueRange, 0> valueRangeValues;
128 
129     /// The originating pattern that was matched. This is always non-null, but
130     /// represented with a pointer to allow for assignment.
131     const PDLByteCodePattern *pattern;
132     /// The current benefit of the pattern that was matched.
133     PatternBenefit benefit;
134   };
135 
136   /// Create a ByteCode instance from the given module containing operations in
137   /// the PDL interpreter dialect.
138   PDLByteCode(ModuleOp module,
139               llvm::StringMap<PDLConstraintFunction> constraintFns,
140               llvm::StringMap<PDLRewriteFunction> rewriteFns);
141 
142   /// Return the patterns held by the bytecode.
getPatterns()143   ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
144 
145   /// Initialize the given state such that it can be used to execute the current
146   /// bytecode.
147   void initializeMutableState(PDLByteCodeMutableState &state) const;
148 
149   /// Run the pattern matcher on the given root operation, collecting the
150   /// matched patterns in `matches`.
151   void match(Operation *op, PatternRewriter &rewriter,
152              SmallVectorImpl<MatchResult> &matches,
153              PDLByteCodeMutableState &state) const;
154 
155   /// Run the rewriter of the given pattern that was previously matched in
156   /// `match`.
157   void rewrite(PatternRewriter &rewriter, const MatchResult &match,
158                PDLByteCodeMutableState &state) const;
159 
160 private:
161   /// Execute the given byte code starting at the provided instruction `inst`.
162   /// `matches` is an optional field provided when this function is executed in
163   /// a matching context.
164   void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
165                        PDLByteCodeMutableState &state,
166                        SmallVectorImpl<MatchResult> *matches) const;
167 
168   /// A vector containing pointers to uniqued data. The storage is intentionally
169   /// opaque such that we can store a wide range of data types. The types of
170   /// data stored here include:
171   ///  * Attribute, Identifier, OperationName, Type
172   std::vector<const void *> uniquedData;
173 
174   /// A vector containing the generated bytecode for the matcher.
175   SmallVector<ByteCodeField, 64> matcherByteCode;
176 
177   /// A vector containing the generated bytecode for all of the rewriters.
178   SmallVector<ByteCodeField, 64> rewriterByteCode;
179 
180   /// The set of patterns contained within the bytecode.
181   SmallVector<PDLByteCodePattern, 32> patterns;
182 
183   /// A set of user defined functions invoked via PDL.
184   std::vector<PDLConstraintFunction> constraintFunctions;
185   std::vector<PDLRewriteFunction> rewriteFunctions;
186 
187   /// The maximum memory index used by a value.
188   ByteCodeField maxValueMemoryIndex = 0;
189 
190   /// The maximum number of different types of ranges.
191   ByteCodeField maxTypeRangeCount = 0;
192   ByteCodeField maxValueRangeCount = 0;
193 };
194 
195 } // end namespace detail
196 } // end namespace mlir
197 
198 #endif // MLIR_REWRITE_BYTECODE_H_
199