1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Common utility functions
25 //===----------------------------------------------------------------------===//
26 
27 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
28 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute boolAttr)29 static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
30   if (!boolAttr)
31     return llvm::None;
32 
33   auto type = boolAttr.getType();
34   if (type.isInteger(1)) {
35     auto attr = boolAttr.cast<BoolAttr>();
36     return attr.getValue();
37   }
38   if (auto vecType = type.cast<VectorType>()) {
39     if (vecType.getElementType().isInteger(1))
40       if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
41         return attr.getSplatValue<bool>();
42   }
43   return llvm::None;
44 }
45 
46 // Extracts an element from the given `composite` by following the given
47 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)48 static Attribute extractCompositeElement(Attribute composite,
49                                          ArrayRef<unsigned> indices) {
50   // Check that given composite is a constant.
51   if (!composite)
52     return {};
53   // Return composite itself if we reach the end of the index chain.
54   if (indices.empty())
55     return composite;
56 
57   if (auto vector = composite.dyn_cast<ElementsAttr>()) {
58     assert(indices.size() == 1 && "must have exactly one index for a vector");
59     return vector.getValue({indices[0]});
60   }
61 
62   if (auto array = composite.dyn_cast<ArrayAttr>()) {
63     assert(!indices.empty() && "must have at least one index for an array");
64     return extractCompositeElement(array.getValue()[indices[0]],
65                                    indices.drop_front());
66   }
67 
68   return {};
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // TableGen'erated canonicalizers
73 //===----------------------------------------------------------------------===//
74 
// The generated patterns are textually included inside an anonymous namespace.
// NOTE(review): pattern names used below but not defined in this file (e.g.
// `ConvertChainedBitcast`) presumably come from this include -- confirm.
namespace {
#include "SPIRVCanonicalization.inc"
} // end anonymous namespace
78 
79 //===----------------------------------------------------------------------===//
80 // spv.AccessChainOp
81 //===----------------------------------------------------------------------===//
82 
83 namespace {
84 
85 /// Combines chained `spirv::AccessChainOp` operations into one
86 /// `spirv::AccessChainOp` operation.
87 struct CombineChainedAccessChain
88     : public OpRewritePattern<spirv::AccessChainOp> {
89   using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
90 
matchAndRewrite__anon8d29c6030211::CombineChainedAccessChain91   LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
92                                 PatternRewriter &rewriter) const override {
93     auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
94         accessChainOp.base_ptr().getDefiningOp());
95 
96     if (!parentAccessChainOp) {
97       return failure();
98     }
99 
100     // Combine indices.
101     SmallVector<Value, 4> indices(parentAccessChainOp.indices());
102     indices.append(accessChainOp.indices().begin(),
103                    accessChainOp.indices().end());
104 
105     rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
106         accessChainOp, parentAccessChainOp.base_ptr(), indices);
107 
108     return success();
109   }
110 };
111 } // end anonymous namespace
112 
/// Adds the `CombineChainedAccessChain` pattern to `results`, constructing it
/// with `context`.
void spirv::AccessChainOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<CombineChainedAccessChain>(context);
}
117 
118 //===----------------------------------------------------------------------===//
119 // spv.BitcastOp
120 //===----------------------------------------------------------------------===//
121 
/// Adds the `ConvertChainedBitcast` pattern to `results`, constructing it with
/// `context`.
// NOTE(review): `ConvertChainedBitcast` is not defined in this file; it
// presumably comes from the included "SPIRVCanonicalization.inc" -- confirm.
void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
  results.add<ConvertChainedBitcast>(context);
}
126 
127 //===----------------------------------------------------------------------===//
128 // spv.CompositeExtractOp
129 //===----------------------------------------------------------------------===//
130 
fold(ArrayRef<Attribute> operands)131 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
132   assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
133   auto indexVector =
134       llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
135         return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
136       }));
137   return extractCompositeElement(operands[0], indexVector);
138 }
139 
140 //===----------------------------------------------------------------------===//
141 // spv.Constant
142 //===----------------------------------------------------------------------===//
143 
/// Folds the constant op to whatever its `value()` accessor returns. The op
/// is asserted to have no operands, so `operands` is expected to be empty.
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "spv.Constant has no operands");
  return value();
}
148 
149 //===----------------------------------------------------------------------===//
150 // spv.IAdd
151 //===----------------------------------------------------------------------===//
152 
fold(ArrayRef<Attribute> operands)153 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
154   assert(operands.size() == 2 && "spv.IAdd expects two operands");
155   // x + 0 = x
156   if (matchPattern(operand2(), m_Zero()))
157     return operand1();
158 
159   // According to the SPIR-V spec:
160   //
161   // The resulting value will equal the low-order N bits of the correct result
162   // R, where N is the component width and R is computed with enough precision
163   // to avoid overflow and underflow.
164   return constFoldBinaryOp<IntegerAttr>(operands,
165                                         [](APInt a, APInt b) { return a + b; });
166 }
167 
168 //===----------------------------------------------------------------------===//
169 // spv.IMul
170 //===----------------------------------------------------------------------===//
171 
fold(ArrayRef<Attribute> operands)172 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
173   assert(operands.size() == 2 && "spv.IMul expects two operands");
174   // x * 0 == 0
175   if (matchPattern(operand2(), m_Zero()))
176     return operand2();
177   // x * 1 = x
178   if (matchPattern(operand2(), m_One()))
179     return operand1();
180 
181   // According to the SPIR-V spec:
182   //
183   // The resulting value will equal the low-order N bits of the correct result
184   // R, where N is the component width and R is computed with enough precision
185   // to avoid overflow and underflow.
186   return constFoldBinaryOp<IntegerAttr>(operands,
187                                         [](APInt a, APInt b) { return a * b; });
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // spv.ISub
192 //===----------------------------------------------------------------------===//
193 
fold(ArrayRef<Attribute> operands)194 OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
195   // x - x = 0
196   if (operand1() == operand2())
197     return Builder(getContext()).getIntegerAttr(getType(), 0);
198 
199   // According to the SPIR-V spec:
200   //
201   // The resulting value will equal the low-order N bits of the correct result
202   // R, where N is the component width and R is computed with enough precision
203   // to avoid overflow and underflow.
204   return constFoldBinaryOp<IntegerAttr>(operands,
205                                         [](APInt a, APInt b) { return a - b; });
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // spv.LogicalAnd
210 //===----------------------------------------------------------------------===//
211 
fold(ArrayRef<Attribute> operands)212 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
213   assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
214 
215   if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
216     // x && true = x
217     if (rhs.getValue())
218       return operand1();
219 
220     // x && false = false
221     if (!rhs.getValue())
222       return operands.back();
223   }
224 
225   return Attribute();
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // spv.LogicalNot
230 //===----------------------------------------------------------------------===//
231 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)232 void spirv::LogicalNotOp::getCanonicalizationPatterns(
233     RewritePatternSet &results, MLIRContext *context) {
234   results
235       .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
236            ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
237           context);
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // spv.LogicalOr
242 //===----------------------------------------------------------------------===//
243 
fold(ArrayRef<Attribute> operands)244 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
245   assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
246 
247   if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
248     if (rhs.getValue())
249       // x || true = true
250       return operands.back();
251 
252     // x || false = x
253     if (!rhs.getValue())
254       return operand1();
255   }
256 
257   return Attribute();
258 }
259 
260 //===----------------------------------------------------------------------===//
261 // spv.mlir.selection
262 //===----------------------------------------------------------------------===//
263 
264 namespace {
265 // Blocks from the given `spv.mlir.selection` operation must satisfy the
266 // following layout:
267 //
268 //       +-----------------------------------------------+
269 //       | header block                                  |
270 //       | spv.BranchConditionalOp %cond, ^case0, ^case1 |
271 //       +-----------------------------------------------+
272 //                            /   \
273 //                             ...
274 //
275 //
276 //   +------------------------+    +------------------------+
277 //   | case #0                |    | case #1                |
278 //   | spv.Store %ptr %value0 |    | spv.Store %ptr %value1 |
279 //   | spv.Branch ^merge      |    | spv.Branch ^merge      |
280 //   +------------------------+    +------------------------+
281 //
282 //
283 //                             ...
284 //                            \   /
285 //                              v
286 //                       +-------------+
287 //                       | merge block |
288 //                       +-------------+
289 //
290 struct ConvertSelectionOpToSelect
291     : public OpRewritePattern<spirv::SelectionOp> {
292   using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
293 
matchAndRewrite__anon8d29c6030711::ConvertSelectionOpToSelect294   LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
295                                 PatternRewriter &rewriter) const override {
296     auto *op = selectionOp.getOperation();
297     auto &body = op->getRegion(0);
298     // Verifier allows an empty region for `spv.mlir.selection`.
299     if (body.empty()) {
300       return failure();
301     }
302 
303     // Check that region consists of 4 blocks:
304     // header block, `true` block, `false` block and merge block.
305     if (std::distance(body.begin(), body.end()) != 4) {
306       return failure();
307     }
308 
309     auto *headerBlock = selectionOp.getHeaderBlock();
310     if (!onlyContainsBranchConditionalOp(headerBlock)) {
311       return failure();
312     }
313 
314     auto brConditionalOp =
315         cast<spirv::BranchConditionalOp>(headerBlock->front());
316 
317     auto *trueBlock = brConditionalOp.getSuccessor(0);
318     auto *falseBlock = brConditionalOp.getSuccessor(1);
319     auto *mergeBlock = selectionOp.getMergeBlock();
320 
321     if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
322       return failure();
323 
324     auto trueValue = getSrcValue(trueBlock);
325     auto falseValue = getSrcValue(falseBlock);
326     auto ptrValue = getDstPtr(trueBlock);
327     auto storeOpAttributes =
328         cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
329 
330     auto selectOp = rewriter.create<spirv::SelectOp>(
331         selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
332         trueValue, falseValue);
333     rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
334                                     selectOp.getResult(), storeOpAttributes);
335 
336     // `spv.mlir.selection` is not needed anymore.
337     rewriter.eraseOp(op);
338     return success();
339   }
340 
341 private:
342   // Checks that given blocks follow the following rules:
343   // 1. Each conditional block consists of two operations, the first operation
344   //    is a `spv.Store` and the last operation is a `spv.Branch`.
345   // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
346   // 3. A control flow goes into the given merge block from the given
347   //    conditional blocks.
348   LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
349                                          Block *mergeBlock) const;
350 
onlyContainsBranchConditionalOp__anon8d29c6030711::ConvertSelectionOpToSelect351   bool onlyContainsBranchConditionalOp(Block *block) const {
352     return std::next(block->begin()) == block->end() &&
353            isa<spirv::BranchConditionalOp>(block->front());
354   }
355 
isSameAttrList__anon8d29c6030711::ConvertSelectionOpToSelect356   bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
357     return lhs->getAttrDictionary() == rhs->getAttrDictionary();
358   }
359 
360   // Returns a source value for the given block.
getSrcValue__anon8d29c6030711::ConvertSelectionOpToSelect361   Value getSrcValue(Block *block) const {
362     auto storeOp = cast<spirv::StoreOp>(block->front());
363     return storeOp.value();
364   }
365 
366   // Returns a destination value for the given block.
getDstPtr__anon8d29c6030711::ConvertSelectionOpToSelect367   Value getDstPtr(Block *block) const {
368     auto storeOp = cast<spirv::StoreOp>(block->front());
369     return storeOp.ptr();
370   }
371 };
372 
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const373 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
374     Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
375   // Each block must consists of 2 operations.
376   if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
377       (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
378     return failure();
379   }
380 
381   auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
382   auto trueBrBranchOp =
383       dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
384   auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
385   auto falseBrBranchOp =
386       dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
387 
388   if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
389       !falseBrBranchOp) {
390     return failure();
391   }
392 
393   // Checks that given type is valid for `spv.SelectOp`.
394   // According to SPIR-V spec:
395   // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
396   // Starting with version 1.4, Result Type can additionally be a composite type
397   // other than a vector."
398   bool isScalarOrVector = trueBrStoreOp.value()
399                               .getType()
400                               .cast<spirv::SPIRVType>()
401                               .isScalarOrVector();
402 
403   // Check that each `spv.Store` uses the same pointer, memory access
404   // attributes and a valid type of the value.
405   if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
406       !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
407     return failure();
408   }
409 
410   if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
411       (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
412     return failure();
413   }
414 
415   return success();
416 }
417 } // end anonymous namespace
418 
/// Adds the `ConvertSelectionOpToSelect` pattern to `results`, constructing it
/// with `context`.
void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                     MLIRContext *context) {
  results.add<ConvertSelectionOpToSelect>(context);
}
423