1 //===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the folders and canonicalization patterns for SPIR-V ops.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15 #include "mlir/Dialect/CommonFolders.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20
21 using namespace mlir;
22
23 //===----------------------------------------------------------------------===//
24 // Common utility functions
25 //===----------------------------------------------------------------------===//
26
27 /// Returns the boolean value under the hood if the given `boolAttr` is a scalar
28 /// or splat vector bool constant.
getScalarOrSplatBoolAttr(Attribute boolAttr)29 static Optional<bool> getScalarOrSplatBoolAttr(Attribute boolAttr) {
30 if (!boolAttr)
31 return llvm::None;
32
33 auto type = boolAttr.getType();
34 if (type.isInteger(1)) {
35 auto attr = boolAttr.cast<BoolAttr>();
36 return attr.getValue();
37 }
38 if (auto vecType = type.cast<VectorType>()) {
39 if (vecType.getElementType().isInteger(1))
40 if (auto attr = boolAttr.dyn_cast<SplatElementsAttr>())
41 return attr.getSplatValue<bool>();
42 }
43 return llvm::None;
44 }
45
46 // Extracts an element from the given `composite` by following the given
47 // `indices`. Returns a null Attribute if error happens.
extractCompositeElement(Attribute composite,ArrayRef<unsigned> indices)48 static Attribute extractCompositeElement(Attribute composite,
49 ArrayRef<unsigned> indices) {
50 // Check that given composite is a constant.
51 if (!composite)
52 return {};
53 // Return composite itself if we reach the end of the index chain.
54 if (indices.empty())
55 return composite;
56
57 if (auto vector = composite.dyn_cast<ElementsAttr>()) {
58 assert(indices.size() == 1 && "must have exactly one index for a vector");
59 return vector.getValue({indices[0]});
60 }
61
62 if (auto array = composite.dyn_cast<ArrayAttr>()) {
63 assert(!indices.empty() && "must have at least one index for an array");
64 return extractCompositeElement(array.getValue()[indices[0]],
65 indices.drop_front());
66 }
67
68 return {};
69 }
70
71 //===----------------------------------------------------------------------===//
72 // TableGen'erated canonicalizers
73 //===----------------------------------------------------------------------===//
74
75 namespace {
76 #include "SPIRVCanonicalization.inc"
77 }
78
79 //===----------------------------------------------------------------------===//
80 // spv.AccessChainOp
81 //===----------------------------------------------------------------------===//
82
83 namespace {
84
85 /// Combines chained `spirv::AccessChainOp` operations into one
86 /// `spirv::AccessChainOp` operation.
87 struct CombineChainedAccessChain
88 : public OpRewritePattern<spirv::AccessChainOp> {
89 using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
90
matchAndRewrite__anon8d29c6030211::CombineChainedAccessChain91 LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
92 PatternRewriter &rewriter) const override {
93 auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
94 accessChainOp.base_ptr().getDefiningOp());
95
96 if (!parentAccessChainOp) {
97 return failure();
98 }
99
100 // Combine indices.
101 SmallVector<Value, 4> indices(parentAccessChainOp.indices());
102 indices.append(accessChainOp.indices().begin(),
103 accessChainOp.indices().end());
104
105 rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
106 accessChainOp, parentAccessChainOp.base_ptr(), indices);
107
108 return success();
109 }
110 };
111 } // end anonymous namespace
112
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)113 void spirv::AccessChainOp::getCanonicalizationPatterns(
114 RewritePatternSet &results, MLIRContext *context) {
115 results.add<CombineChainedAccessChain>(context);
116 }
117
118 //===----------------------------------------------------------------------===//
119 // spv.BitcastOp
120 //===----------------------------------------------------------------------===//
121
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)122 void spirv::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
123 MLIRContext *context) {
124 results.add<ConvertChainedBitcast>(context);
125 }
126
127 //===----------------------------------------------------------------------===//
128 // spv.CompositeExtractOp
129 //===----------------------------------------------------------------------===//
130
fold(ArrayRef<Attribute> operands)131 OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
132 assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
133 auto indexVector =
134 llvm::to_vector<8>(llvm::map_range(indices(), [](Attribute attr) {
135 return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
136 }));
137 return extractCompositeElement(operands[0], indexVector);
138 }
139
140 //===----------------------------------------------------------------------===//
141 // spv.Constant
142 //===----------------------------------------------------------------------===//
143
fold(ArrayRef<Attribute> operands)144 OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
145 assert(operands.empty() && "spv.Constant has no operands");
146 return value();
147 }
148
149 //===----------------------------------------------------------------------===//
150 // spv.IAdd
151 //===----------------------------------------------------------------------===//
152
fold(ArrayRef<Attribute> operands)153 OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
154 assert(operands.size() == 2 && "spv.IAdd expects two operands");
155 // x + 0 = x
156 if (matchPattern(operand2(), m_Zero()))
157 return operand1();
158
159 // According to the SPIR-V spec:
160 //
161 // The resulting value will equal the low-order N bits of the correct result
162 // R, where N is the component width and R is computed with enough precision
163 // to avoid overflow and underflow.
164 return constFoldBinaryOp<IntegerAttr>(operands,
165 [](APInt a, APInt b) { return a + b; });
166 }
167
168 //===----------------------------------------------------------------------===//
169 // spv.IMul
170 //===----------------------------------------------------------------------===//
171
fold(ArrayRef<Attribute> operands)172 OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
173 assert(operands.size() == 2 && "spv.IMul expects two operands");
174 // x * 0 == 0
175 if (matchPattern(operand2(), m_Zero()))
176 return operand2();
177 // x * 1 = x
178 if (matchPattern(operand2(), m_One()))
179 return operand1();
180
181 // According to the SPIR-V spec:
182 //
183 // The resulting value will equal the low-order N bits of the correct result
184 // R, where N is the component width and R is computed with enough precision
185 // to avoid overflow and underflow.
186 return constFoldBinaryOp<IntegerAttr>(operands,
187 [](APInt a, APInt b) { return a * b; });
188 }
189
190 //===----------------------------------------------------------------------===//
191 // spv.ISub
192 //===----------------------------------------------------------------------===//
193
fold(ArrayRef<Attribute> operands)194 OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
195 // x - x = 0
196 if (operand1() == operand2())
197 return Builder(getContext()).getIntegerAttr(getType(), 0);
198
199 // According to the SPIR-V spec:
200 //
201 // The resulting value will equal the low-order N bits of the correct result
202 // R, where N is the component width and R is computed with enough precision
203 // to avoid overflow and underflow.
204 return constFoldBinaryOp<IntegerAttr>(operands,
205 [](APInt a, APInt b) { return a - b; });
206 }
207
208 //===----------------------------------------------------------------------===//
209 // spv.LogicalAnd
210 //===----------------------------------------------------------------------===//
211
fold(ArrayRef<Attribute> operands)212 OpFoldResult spirv::LogicalAndOp::fold(ArrayRef<Attribute> operands) {
213 assert(operands.size() == 2 && "spv.LogicalAnd should take two operands");
214
215 if (Optional<bool> rhs = getScalarOrSplatBoolAttr(operands.back())) {
216 // x && true = x
217 if (rhs.getValue())
218 return operand1();
219
220 // x && false = false
221 if (!rhs.getValue())
222 return operands.back();
223 }
224
225 return Attribute();
226 }
227
228 //===----------------------------------------------------------------------===//
229 // spv.LogicalNot
230 //===----------------------------------------------------------------------===//
231
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)232 void spirv::LogicalNotOp::getCanonicalizationPatterns(
233 RewritePatternSet &results, MLIRContext *context) {
234 results
235 .add<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
236 ConvertLogicalNotOfLogicalEqual, ConvertLogicalNotOfLogicalNotEqual>(
237 context);
238 }
239
240 //===----------------------------------------------------------------------===//
241 // spv.LogicalOr
242 //===----------------------------------------------------------------------===//
243
fold(ArrayRef<Attribute> operands)244 OpFoldResult spirv::LogicalOrOp::fold(ArrayRef<Attribute> operands) {
245 assert(operands.size() == 2 && "spv.LogicalOr should take two operands");
246
247 if (auto rhs = getScalarOrSplatBoolAttr(operands.back())) {
248 if (rhs.getValue())
249 // x || true = true
250 return operands.back();
251
252 // x || false = x
253 if (!rhs.getValue())
254 return operand1();
255 }
256
257 return Attribute();
258 }
259
260 //===----------------------------------------------------------------------===//
261 // spv.mlir.selection
262 //===----------------------------------------------------------------------===//
263
264 namespace {
265 // Blocks from the given `spv.mlir.selection` operation must satisfy the
266 // following layout:
267 //
268 // +-----------------------------------------------+
269 // | header block |
270 // | spv.BranchConditionalOp %cond, ^case0, ^case1 |
271 // +-----------------------------------------------+
272 // / \
273 // ...
274 //
275 //
276 // +------------------------+ +------------------------+
277 // | case #0 | | case #1 |
278 // | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
279 // | spv.Branch ^merge | | spv.Branch ^merge |
280 // +------------------------+ +------------------------+
281 //
282 //
283 // ...
284 // \ /
285 // v
286 // +-------------+
287 // | merge block |
288 // +-------------+
289 //
290 struct ConvertSelectionOpToSelect
291 : public OpRewritePattern<spirv::SelectionOp> {
292 using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
293
matchAndRewrite__anon8d29c6030711::ConvertSelectionOpToSelect294 LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp,
295 PatternRewriter &rewriter) const override {
296 auto *op = selectionOp.getOperation();
297 auto &body = op->getRegion(0);
298 // Verifier allows an empty region for `spv.mlir.selection`.
299 if (body.empty()) {
300 return failure();
301 }
302
303 // Check that region consists of 4 blocks:
304 // header block, `true` block, `false` block and merge block.
305 if (std::distance(body.begin(), body.end()) != 4) {
306 return failure();
307 }
308
309 auto *headerBlock = selectionOp.getHeaderBlock();
310 if (!onlyContainsBranchConditionalOp(headerBlock)) {
311 return failure();
312 }
313
314 auto brConditionalOp =
315 cast<spirv::BranchConditionalOp>(headerBlock->front());
316
317 auto *trueBlock = brConditionalOp.getSuccessor(0);
318 auto *falseBlock = brConditionalOp.getSuccessor(1);
319 auto *mergeBlock = selectionOp.getMergeBlock();
320
321 if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)))
322 return failure();
323
324 auto trueValue = getSrcValue(trueBlock);
325 auto falseValue = getSrcValue(falseBlock);
326 auto ptrValue = getDstPtr(trueBlock);
327 auto storeOpAttributes =
328 cast<spirv::StoreOp>(trueBlock->front())->getAttrs();
329
330 auto selectOp = rewriter.create<spirv::SelectOp>(
331 selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
332 trueValue, falseValue);
333 rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
334 selectOp.getResult(), storeOpAttributes);
335
336 // `spv.mlir.selection` is not needed anymore.
337 rewriter.eraseOp(op);
338 return success();
339 }
340
341 private:
342 // Checks that given blocks follow the following rules:
343 // 1. Each conditional block consists of two operations, the first operation
344 // is a `spv.Store` and the last operation is a `spv.Branch`.
345 // 2. Each `spv.Store` uses the same pointer and the same memory attributes.
346 // 3. A control flow goes into the given merge block from the given
347 // conditional blocks.
348 LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock,
349 Block *mergeBlock) const;
350
onlyContainsBranchConditionalOp__anon8d29c6030711::ConvertSelectionOpToSelect351 bool onlyContainsBranchConditionalOp(Block *block) const {
352 return std::next(block->begin()) == block->end() &&
353 isa<spirv::BranchConditionalOp>(block->front());
354 }
355
isSameAttrList__anon8d29c6030711::ConvertSelectionOpToSelect356 bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
357 return lhs->getAttrDictionary() == rhs->getAttrDictionary();
358 }
359
360 // Returns a source value for the given block.
getSrcValue__anon8d29c6030711::ConvertSelectionOpToSelect361 Value getSrcValue(Block *block) const {
362 auto storeOp = cast<spirv::StoreOp>(block->front());
363 return storeOp.value();
364 }
365
366 // Returns a destination value for the given block.
getDstPtr__anon8d29c6030711::ConvertSelectionOpToSelect367 Value getDstPtr(Block *block) const {
368 auto storeOp = cast<spirv::StoreOp>(block->front());
369 return storeOp.ptr();
370 }
371 };
372
canCanonicalizeSelection(Block * trueBlock,Block * falseBlock,Block * mergeBlock) const373 LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
374 Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
375 // Each block must consists of 2 operations.
376 if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
377 (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
378 return failure();
379 }
380
381 auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
382 auto trueBrBranchOp =
383 dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
384 auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
385 auto falseBrBranchOp =
386 dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
387
388 if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
389 !falseBrBranchOp) {
390 return failure();
391 }
392
393 // Checks that given type is valid for `spv.SelectOp`.
394 // According to SPIR-V spec:
395 // "Before version 1.4, Result Type must be a pointer, scalar, or vector.
396 // Starting with version 1.4, Result Type can additionally be a composite type
397 // other than a vector."
398 bool isScalarOrVector = trueBrStoreOp.value()
399 .getType()
400 .cast<spirv::SPIRVType>()
401 .isScalarOrVector();
402
403 // Check that each `spv.Store` uses the same pointer, memory access
404 // attributes and a valid type of the value.
405 if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
406 !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isScalarOrVector) {
407 return failure();
408 }
409
410 if ((trueBrBranchOp->getSuccessor(0) != mergeBlock) ||
411 (falseBrBranchOp->getSuccessor(0) != mergeBlock)) {
412 return failure();
413 }
414
415 return success();
416 }
417 } // end anonymous namespace
418
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)419 void spirv::SelectionOp::getCanonicalizationPatterns(RewritePatternSet &results,
420 MLIRContext *context) {
421 results.add<ConvertSelectionOpToSelect>(context);
422 }
423