1 //===- StandardToSPIRV.cpp - Standard to SPIR-V Patterns ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert standard dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/Debug.h"
22
23 #define DEBUG_TYPE "std-to-spirv-pattern"
24
25 using namespace mlir;
26
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30
31 /// Returns true if the given `type` is a boolean scalar or vector type.
isBoolScalarOrVector(Type type)32 static bool isBoolScalarOrVector(Type type) {
33 if (type.isInteger(1))
34 return true;
35 if (auto vecType = type.dyn_cast<VectorType>())
36 return vecType.getElementType().isInteger(1);
37 return false;
38 }
39
40 /// Converts the given `srcAttr` into a boolean attribute if it holds an
41 /// integral value. Returns null attribute if conversion fails.
convertBoolAttr(Attribute srcAttr,Builder builder)42 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
43 if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
44 return boolAttr;
45 if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
46 return builder.getBoolAttr(intAttr.getValue().getBoolValue());
47 return BoolAttr();
48 }
49
50 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
51 /// Returns null attribute if conversion fails.
convertIntegerAttr(IntegerAttr srcAttr,IntegerType dstType,Builder builder)52 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
53 Builder builder) {
54 // If the source number uses less active bits than the target bitwidth, then
55 // it should be safe to convert.
56 if (srcAttr.getValue().isIntN(dstType.getWidth()))
57 return builder.getIntegerAttr(dstType, srcAttr.getInt());
58
59 // XXX: Try again by interpreting the source number as a signed value.
60 // Although integers in the standard dialect are signless, they can represent
61 // a signed number. It's the operation decides how to interpret. This is
62 // dangerous, but it seems there is no good way of handling this if we still
63 // want to change the bitwidth. Emit a message at least.
64 if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
65 auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
66 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
67 << dstAttr << "' for type '" << dstType << "'\n");
68 return dstAttr;
69 }
70
71 LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
72 << "' illegal: cannot fit into target type '"
73 << dstType << "'\n");
74 return IntegerAttr();
75 }
76
77 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
78 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
convertFloatAttr(FloatAttr srcAttr,FloatType dstType,Builder builder)79 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
80 Builder builder) {
81 // Only support converting to float for now.
82 if (!dstType.isF32())
83 return FloatAttr();
84
85 // Try to convert the source floating-point number to single precision.
86 APFloat dstVal = srcAttr.getValue();
87 bool losesInfo = false;
88 APFloat::opStatus status =
89 dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
90 if (status != APFloat::opOK || losesInfo) {
91 LLVM_DEBUG(llvm::dbgs()
92 << srcAttr << " illegal: cannot fit into converted type '"
93 << dstType << "'\n");
94 return FloatAttr();
95 }
96
97 return builder.getF32FloatAttr(dstVal.convertToFloat());
98 }
99
100 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
101 /// the sign of `signOperand`.
102 ///
103 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
104 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
105 /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
106 /// if either operand can be negative. Emulate it via spv.UMod.
emulateSignedRemainder(Location loc,Value lhs,Value rhs,Value signOperand,OpBuilder & builder)107 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
108 Value signOperand, OpBuilder &builder) {
109 assert(lhs.getType() == rhs.getType());
110 assert(lhs == signOperand || rhs == signOperand);
111
112 Type type = lhs.getType();
113
114 // Calculate the remainder with spv.UMod.
115 Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
116 Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
117 Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
118
119 // Fix the sign.
120 Value isPositive;
121 if (lhs == signOperand)
122 isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
123 else
124 isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
125 Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
126 return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
127 }
128
129 /// Returns the offset of the value in `targetBits` representation.
130 ///
131 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
132 /// It's assumed to be non-negative.
133 ///
134 /// When accessing an element in the array treating as having elements of
135 /// `targetBits`, multiple values are loaded in the same time. The method
136 /// returns the offset where the `srcIdx` locates in the value. For example, if
137 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
138 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
139 /// element has 8 bits.
getOffsetForBitwidth(Location loc,Value srcIdx,int sourceBits,int targetBits,OpBuilder & builder)140 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
141 int targetBits, OpBuilder &builder) {
142 assert(targetBits % sourceBits == 0);
143 IntegerType targetType = builder.getIntegerType(targetBits);
144 IntegerAttr idxAttr =
145 builder.getIntegerAttr(targetType, targetBits / sourceBits);
146 auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
147 IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
148 auto srcBitsValue =
149 builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
150 auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
151 return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
152 }
153
154 /// Returns an adjusted spirv::AccessChainOp. Based on the
155 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
156 /// supported. During conversion if a memref of an unsupported type is used,
157 /// load/stores to this memref need to be modified to use a supported higher
158 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
159 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
160 /// bits needed. The extraction of the actual bits needed are handled
161 /// separately. Note that this only works for a 1-D tensor.
adjustAccessChainForBitwidth(SPIRVTypeConverter & typeConverter,spirv::AccessChainOp op,int sourceBits,int targetBits,OpBuilder & builder)162 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
163 spirv::AccessChainOp op,
164 int sourceBits, int targetBits,
165 OpBuilder &builder) {
166 assert(targetBits % sourceBits == 0);
167 const auto loc = op.getLoc();
168 IntegerType targetType = builder.getIntegerType(targetBits);
169 IntegerAttr attr =
170 builder.getIntegerAttr(targetType, targetBits / sourceBits);
171 auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
172 auto lastDim = op->getOperand(op.getNumOperands() - 1);
173 auto indices = llvm::to_vector<4>(op.indices());
174 // There are two elements if this is a 1-D tensor.
175 assert(indices.size() == 2);
176 indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
177 Type t = typeConverter.convertType(op.component_ptr().getType());
178 return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
179 }
180
181 /// Returns the shifted `targetBits`-bit value with the given offset.
shiftValue(Location loc,Value value,Value offset,Value mask,int targetBits,OpBuilder & builder)182 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
183 int targetBits, OpBuilder &builder) {
184 Type targetType = builder.getIntegerType(targetBits);
185 Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
186 return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
187 offset);
188 }
189
190 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
isAllocationSupported(MemRefType t)191 static bool isAllocationSupported(MemRefType t) {
192 // Currently only support workgroup local memory allocations with static
193 // shape and int or float or vector of int or float element type.
194 if (!(t.hasStaticShape() &&
195 SPIRVTypeConverter::getMemorySpaceForStorageClass(
196 spirv::StorageClass::Workgroup) == t.getMemorySpace()))
197 return false;
198 Type elementType = t.getElementType();
199 if (auto vecType = elementType.dyn_cast<VectorType>())
200 elementType = vecType.getElementType();
201 return elementType.isIntOrFloat();
202 }
203
204 /// Returns the scope to use for atomic operations use for emulating store
205 /// operations of unsupported integer bitwidths, based on the memref
206 /// type. Returns None on failure.
getAtomicOpScope(MemRefType t)207 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
208 Optional<spirv::StorageClass> storageClass =
209 SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
210 if (!storageClass)
211 return {};
212 switch (*storageClass) {
213 case spirv::StorageClass::StorageBuffer:
214 return spirv::Scope::Device;
215 case spirv::StorageClass::Workgroup:
216 return spirv::Scope::Workgroup;
217 default: {
218 }
219 }
220 return {};
221 }
222
223 //===----------------------------------------------------------------------===//
224 // Operation conversion
225 //===----------------------------------------------------------------------===//
226
227 // Note that DRR cannot be used for the patterns in this file: we may need to
228 // convert type along the way, which requires ConversionPattern. DRR generates
229 // normal RewritePattern.
230
231 namespace {
232
233 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
234 /// to Workgroup memory when the size is constant. Note that this pattern needs
235 /// to be applied in a pass that runs at least at spv.module scope since it wil
236 /// ladd global variables into the spv.module.
237 class AllocOpPattern final : public OpConversionPattern<AllocOp> {
238 public:
239 using OpConversionPattern<AllocOp>::OpConversionPattern;
240
241 LogicalResult
matchAndRewrite(AllocOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const242 matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
243 ConversionPatternRewriter &rewriter) const override {
244 MemRefType allocType = operation.getType();
245 if (!isAllocationSupported(allocType))
246 return operation.emitError("unhandled allocation type");
247
248 // Get the SPIR-V type for the allocation.
249 Type spirvType = getTypeConverter()->convertType(allocType);
250
251 // Insert spv.globalVariable for this allocation.
252 Operation *parent =
253 SymbolTable::getNearestSymbolTable(operation->getParentOp());
254 if (!parent)
255 return failure();
256 Location loc = operation.getLoc();
257 spirv::GlobalVariableOp varOp;
258 {
259 OpBuilder::InsertionGuard guard(rewriter);
260 Block &entryBlock = *parent->getRegion(0).begin();
261 rewriter.setInsertionPointToStart(&entryBlock);
262 auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
263 std::string varName =
264 std::string("__workgroup_mem__") +
265 std::to_string(std::distance(varOps.begin(), varOps.end()));
266 varOp = rewriter.create<spirv::GlobalVariableOp>(
267 loc, TypeAttr::get(spirvType), varName,
268 /*initializer = */ nullptr);
269 }
270
271 // Get pointer to global variable at the current scope.
272 rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
273 return success();
274 }
275 };
276
277 /// Removed a deallocation if it is a supported allocation. Currently only
278 /// removes deallocation if the memory space is workgroup memory.
279 class DeallocOpPattern final : public OpConversionPattern<DeallocOp> {
280 public:
281 using OpConversionPattern<DeallocOp>::OpConversionPattern;
282
283 LogicalResult
matchAndRewrite(DeallocOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const284 matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
285 ConversionPatternRewriter &rewriter) const override {
286 MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
287 if (!isAllocationSupported(deallocType))
288 return operation.emitError("unhandled deallocation type");
289 rewriter.eraseOp(operation);
290 return success();
291 }
292 };
293
294 /// Converts unary and binary standard operations to SPIR-V operations.
295 template <typename StdOp, typename SPIRVOp>
296 class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
297 public:
298 using OpConversionPattern<StdOp>::OpConversionPattern;
299
300 LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const301 matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
302 ConversionPatternRewriter &rewriter) const override {
303 assert(operands.size() <= 2);
304 auto dstType = this->getTypeConverter()->convertType(operation.getType());
305 if (!dstType)
306 return failure();
307 if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
308 dstType != operation.getType()) {
309 return operation.emitError(
310 "bitwidth emulation is not implemented yet on unsigned op");
311 }
312 rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
313 return success();
314 }
315 };
316
/// Converts std.remi_signed to SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to
/// Vulkan restrictions over spv.SRem and spv.SMod.
class SignedRemIOpPattern final : public OpConversionPattern<SignedRemIOp> {
public:
  using OpConversionPattern<SignedRemIOp>::OpConversionPattern;

  // Defined out-of-line below; emulates the remainder via spv.UMod plus a sign
  // fixup (see emulateSignedRemainder).
  LogicalResult
  matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
329
330 /// Converts bitwise standard operations to SPIR-V operations. This is a special
331 /// pattern other than the BinaryOpPatternPattern because if the operands are
332 /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
333 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
334 template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
335 class BitwiseOpPattern final : public OpConversionPattern<StdOp> {
336 public:
337 using OpConversionPattern<StdOp>::OpConversionPattern;
338
339 LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const340 matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
341 ConversionPatternRewriter &rewriter) const override {
342 assert(operands.size() == 2);
343 auto dstType =
344 this->getTypeConverter()->convertType(operation.getResult().getType());
345 if (!dstType)
346 return failure();
347 if (isBoolScalarOrVector(operands.front().getType())) {
348 rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
349 operands);
350 } else {
351 rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
352 operands);
353 }
354 return success();
355 }
356 };
357
/// Converts composite std.constant operation to spv.constant.
class ConstantCompositeOpPattern final
    : public OpConversionPattern<ConstantOp> {
public:
  using OpConversionPattern<ConstantOp>::OpConversionPattern;

  // Defined out-of-line below; handles vector/tensor constant values.
  LogicalResult
  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
368
/// Converts scalar std.constant operation to spv.constant.
class ConstantScalarOpPattern final : public OpConversionPattern<ConstantOp> {
public:
  using OpConversionPattern<ConstantOp>::OpConversionPattern;

  // Defined out-of-line below; handles int/index/float scalar constants.
  LogicalResult
  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
378
/// Converts floating-point comparison operations to SPIR-V ops.
class CmpFOpPattern final : public OpConversionPattern<CmpFOp> {
public:
  using OpConversionPattern<CmpFOp>::OpConversionPattern;

  // Defined out-of-line below; maps ordered/unordered predicates to the
  // corresponding SPIR-V FOrd*/FUnord* ops.
  LogicalResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
388
/// Converts floating point NaN check to SPIR-V ops. This pattern requires
/// Kernel capability.
class CmpFOpNanKernelPattern final : public OpConversionPattern<CmpFOp> {
public:
  using OpConversionPattern<CmpFOp>::OpConversionPattern;

  // Defined out-of-line below; handles the ORD/UNO predicates only.
  LogicalResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
399
/// Converts floating point NaN check to SPIR-V ops. This pattern does not
/// require additional capability.
class CmpFOpNanNonePattern final : public OpConversionPattern<CmpFOp> {
public:
  using OpConversionPattern<CmpFOp>::OpConversionPattern;

  // Defined out-of-line below; emulates ORD/UNO via spv.IsNan + logical ops.
  LogicalResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
410
/// Converts integer compare operation on i1 type operands to SPIR-V ops.
class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
public:
  using OpConversionPattern<CmpIOp>::OpConversionPattern;

  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
420
/// Converts integer compare operation to SPIR-V ops.
class CmpIOpPattern final : public OpConversionPattern<CmpIOp> {
public:
  using OpConversionPattern<CmpIOp>::OpConversionPattern;

  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
430
/// Converts std.load to spv.Load. Presumably the integer-specific counterpart
/// of LoadOpPattern (cf. IntStoreOpPattern) — the out-of-line definition is
/// not visible in this chunk; confirm there.
class IntLoadOpPattern final : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern<LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
440
/// Converts std.load to spv.Load.
class LoadOpPattern final : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern<LoadOp>::OpConversionPattern;

  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
450
/// Converts std.return to spv.Return.
class ReturnOpPattern final : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
460
/// Converts std.select to spv.Select.
class SelectOpPattern final : public OpConversionPattern<SelectOp> {
public:
  using OpConversionPattern<SelectOp>::OpConversionPattern;
  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
469
/// Converts std.store to spv.Store on integers.
class IntStoreOpPattern final : public OpConversionPattern<StoreOp> {
public:
  using OpConversionPattern<StoreOp>::OpConversionPattern;

  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
479
/// Converts std.store to spv.Store.
class StoreOpPattern final : public OpConversionPattern<StoreOp> {
public:
  using OpConversionPattern<StoreOp>::OpConversionPattern;

  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
489
490 /// Converts std.zexti to spv.Select if the type of source is i1 or vector of
491 /// i1.
492 class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
493 public:
494 using OpConversionPattern<ZeroExtendIOp>::OpConversionPattern;
495
496 LogicalResult
matchAndRewrite(ZeroExtendIOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const497 matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
498 ConversionPatternRewriter &rewriter) const override {
499 auto srcType = operands.front().getType();
500 if (!isBoolScalarOrVector(srcType))
501 return failure();
502
503 auto dstType =
504 this->getTypeConverter()->convertType(op.getResult().getType());
505 Location loc = op.getLoc();
506 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
507 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
508 rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
509 op, dstType, operands.front(), one, zero);
510 return success();
511 }
512 };
513
514 /// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
515 /// i1.
516 class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
517 public:
518 using OpConversionPattern<UIToFPOp>::OpConversionPattern;
519
520 LogicalResult
matchAndRewrite(UIToFPOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const521 matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
522 ConversionPatternRewriter &rewriter) const override {
523 auto srcType = operands.front().getType();
524 if (!isBoolScalarOrVector(srcType))
525 return failure();
526
527 auto dstType =
528 this->getTypeConverter()->convertType(op.getResult().getType());
529 Location loc = op.getLoc();
530 Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
531 Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
532 rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
533 op, dstType, operands.front(), one, zero);
534 return success();
535 }
536 };
537
538 /// Converts type-casting standard operations to SPIR-V operations.
539 template <typename StdOp, typename SPIRVOp>
540 class TypeCastingOpPattern final : public OpConversionPattern<StdOp> {
541 public:
542 using OpConversionPattern<StdOp>::OpConversionPattern;
543
544 LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const545 matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
546 ConversionPatternRewriter &rewriter) const override {
547 assert(operands.size() == 1);
548 auto srcType = operands.front().getType();
549 if (isBoolScalarOrVector(srcType))
550 return failure();
551 auto dstType =
552 this->getTypeConverter()->convertType(operation.getResult().getType());
553 if (dstType == srcType) {
554 // Due to type conversion, we are seeing the same source and target type.
555 // Then we can just erase this operation by forwarding its operand.
556 rewriter.replaceOp(operation, operands.front());
557 } else {
558 rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
559 operands);
560 }
561 return success();
562 }
563 };
564
/// Converts std.xor to SPIR-V operations.
class XOrOpPattern final : public OpConversionPattern<XOrOp> {
public:
  using OpConversionPattern<XOrOp>::OpConversionPattern;

  // Defined out-of-line (not shown in this chunk).
  LogicalResult
  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
574
575 } // namespace
576
577 //===----------------------------------------------------------------------===//
578 // SignedRemIOpPattern
579 //===----------------------------------------------------------------------===//
580
matchAndRewrite(SignedRemIOp remOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const581 LogicalResult SignedRemIOpPattern::matchAndRewrite(
582 SignedRemIOp remOp, ArrayRef<Value> operands,
583 ConversionPatternRewriter &rewriter) const {
584 Value result = emulateSignedRemainder(remOp.getLoc(), operands[0],
585 operands[1], operands[0], rewriter);
586 rewriter.replaceOp(remOp, result);
587
588 return success();
589 }
590
591 //===----------------------------------------------------------------------===//
592 // ConstantOp with composite type.
593 //===----------------------------------------------------------------------===//
594
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const595 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
596 ConstantOp constOp, ArrayRef<Value> operands,
597 ConversionPatternRewriter &rewriter) const {
598 auto srcType = constOp.getType().dyn_cast<ShapedType>();
599 if (!srcType)
600 return failure();
601
602 // std.constant should only have vector or tenor types.
603 assert((srcType.isa<VectorType, RankedTensorType>()));
604
605 auto dstType = getTypeConverter()->convertType(srcType);
606 if (!dstType)
607 return failure();
608
609 auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
610 ShapedType dstAttrType = dstElementsAttr.getType();
611 if (!dstElementsAttr)
612 return failure();
613
614 // If the composite type has more than one dimensions, perform linearization.
615 if (srcType.getRank() > 1) {
616 if (srcType.isa<RankedTensorType>()) {
617 dstAttrType = RankedTensorType::get(srcType.getNumElements(),
618 srcType.getElementType());
619 dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
620 } else {
621 // TODO: add support for large vectors.
622 return failure();
623 }
624 }
625
626 Type srcElemType = srcType.getElementType();
627 Type dstElemType;
628 // Tensor types are converted to SPIR-V array types; vector types are
629 // converted to SPIR-V vector/array types.
630 if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
631 dstElemType = arrayType.getElementType();
632 else
633 dstElemType = dstType.cast<VectorType>().getElementType();
634
635 // If the source and destination element types are different, perform
636 // attribute conversion.
637 if (srcElemType != dstElemType) {
638 SmallVector<Attribute, 8> elements;
639 if (srcElemType.isa<FloatType>()) {
640 for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
641 FloatAttr dstAttr = convertFloatAttr(
642 srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
643 if (!dstAttr)
644 return failure();
645 elements.push_back(dstAttr);
646 }
647 } else if (srcElemType.isInteger(1)) {
648 return failure();
649 } else {
650 for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
651 IntegerAttr dstAttr =
652 convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
653 dstElemType.cast<IntegerType>(), rewriter);
654 if (!dstAttr)
655 return failure();
656 elements.push_back(dstAttr);
657 }
658 }
659
660 // Unfortunately, we cannot use dialect-specific types for element
661 // attributes; element attributes only works with builtin types. So we need
662 // to prepare another converted builtin types for the destination elements
663 // attribute.
664 if (dstAttrType.isa<RankedTensorType>())
665 dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
666 else
667 dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
668
669 dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
670 }
671
672 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
673 dstElementsAttr);
674 return success();
675 }
676
677 //===----------------------------------------------------------------------===//
678 // ConstantOp with scalar type.
679 //===----------------------------------------------------------------------===//
680
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const681 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
682 ConstantOp constOp, ArrayRef<Value> operands,
683 ConversionPatternRewriter &rewriter) const {
684 Type srcType = constOp.getType();
685 if (!srcType.isIntOrIndexOrFloat())
686 return failure();
687
688 Type dstType = getTypeConverter()->convertType(srcType);
689 if (!dstType)
690 return failure();
691
692 // Floating-point types.
693 if (srcType.isa<FloatType>()) {
694 auto srcAttr = constOp.value().cast<FloatAttr>();
695 auto dstAttr = srcAttr;
696
697 // Floating-point types not supported in the target environment are all
698 // converted to float type.
699 if (srcType != dstType) {
700 dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
701 if (!dstAttr)
702 return failure();
703 }
704
705 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
706 return success();
707 }
708
709 // Bool type.
710 if (srcType.isInteger(1)) {
711 // std.constant can use 0/1 instead of true/false for i1 values. We need to
712 // handle that here.
713 auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
714 if (!dstAttr)
715 return failure();
716 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
717 return success();
718 }
719
720 // IndexType or IntegerType. Index values are converted to 32-bit integer
721 // values when converting to SPIR-V.
722 auto srcAttr = constOp.value().cast<IntegerAttr>();
723 auto dstAttr =
724 convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
725 if (!dstAttr)
726 return failure();
727 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
728 return success();
729 }
730
731 //===----------------------------------------------------------------------===//
732 // CmpFOp
733 //===----------------------------------------------------------------------===//
734
/// Converts std.cmpf to the matching SPIR-V floating-point comparison op.
/// Each ordered predicate maps to a spirv::FOrd* op and each unordered
/// predicate to a spirv::FUnord* op; any predicate not listed below falls
/// through to failure so another pattern may handle it.
LogicalResult
CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpFOpAdaptor cmpFOpOperands(operands);

  switch (cmpFOp.getPredicate()) {
// Replaces the std.cmpf with the given SPIR-V op, forwarding the result type
// and the already-converted operands.
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
                                         cmpFOpOperands.lhs(),                 \
                                         cmpFOpOperands.rhs());                \
    return success();

    // Ordered.
    DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
    DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
    DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
    DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
    DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
    DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
    // Unordered.
    DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);

#undef DISPATCH

  default:
    break;
  }
  return failure();
}
770
matchAndRewrite(CmpFOp cmpFOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const771 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
772 CmpFOp cmpFOp, ArrayRef<Value> operands,
773 ConversionPatternRewriter &rewriter) const {
774 CmpFOpAdaptor cmpFOpOperands(operands);
775
776 if (cmpFOp.getPredicate() == CmpFPredicate::ORD) {
777 rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, cmpFOpOperands.lhs(),
778 cmpFOpOperands.rhs());
779 return success();
780 }
781
782 if (cmpFOp.getPredicate() == CmpFPredicate::UNO) {
783 rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(
784 cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs());
785 return success();
786 }
787
788 return failure();
789 }
790
matchAndRewrite(CmpFOp cmpFOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const791 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
792 CmpFOp cmpFOp, ArrayRef<Value> operands,
793 ConversionPatternRewriter &rewriter) const {
794 if (cmpFOp.getPredicate() != CmpFPredicate::ORD &&
795 cmpFOp.getPredicate() != CmpFPredicate::UNO)
796 return failure();
797
798 CmpFOpAdaptor cmpFOpOperands(operands);
799 Location loc = cmpFOp.getLoc();
800
801 Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.lhs());
802 Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.rhs());
803
804 Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
805 if (cmpFOp.getPredicate() == CmpFPredicate::ORD)
806 replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
807
808 rewriter.replaceOp(cmpFOp, replace);
809 return success();
810 }
811
812 //===----------------------------------------------------------------------===//
813 // CmpIOp
814 //===----------------------------------------------------------------------===//
815
816 LogicalResult
matchAndRewrite(CmpIOp cmpIOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const817 BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
818 ConversionPatternRewriter &rewriter) const {
819 CmpIOpAdaptor cmpIOpOperands(operands);
820
821 Type operandType = cmpIOp.lhs().getType();
822 if (!isBoolScalarOrVector(operandType))
823 return failure();
824
825 switch (cmpIOp.getPredicate()) {
826 #define DISPATCH(cmpPredicate, spirvOp) \
827 case cmpPredicate: \
828 rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
829 cmpIOpOperands.lhs(), \
830 cmpIOpOperands.rhs()); \
831 return success();
832
833 DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
834 DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp);
835
836 #undef DISPATCH
837 default:;
838 }
839 return failure();
840 }
841
/// Converts std.cmpi on non-boolean integer operands to the matching SPIR-V
/// integer comparison op. i1 operands are rejected here; they are handled by
/// BoolCmpIOpPattern instead.
LogicalResult
CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                               ConversionPatternRewriter &rewriter) const {
  CmpIOpAdaptor cmpIOpOperands(operands);

  Type operandType = cmpIOp.lhs().getType();
  if (isBoolScalarOrVector(operandType))
    return failure();

  // For unsigned comparison ops, reject operands whose type changes under the
  // type converter (i.e., the bitwidth needs emulation) — presumably because
  // the emulated representation would not compare correctly as unsigned; see
  // the sign extension performed in IntLoadOpPattern. TODO: confirm.
  switch (cmpIOp.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
        operandType != this->getTypeConverter()->convertType(operandType)) {   \
      return cmpIOp.emitError(                                                 \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
                                         cmpIOpOperands.lhs(),                 \
                                         cmpIOpOperands.rhs());                \
    return success();

    DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
    DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
    DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
    DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
    DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
    DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
    DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  return failure();
}
879
880 //===----------------------------------------------------------------------===//
881 // LoadOp
882 //===----------------------------------------------------------------------===//
883
/// Converts std.load on a memref with a signless-integer element type. When
/// the source bitwidth is narrower than the converted (storage) bitwidth
/// (e.g., i8 stored inside an i32 word), emulates the load by loading the
/// containing word, shifting, masking, and sign-extending.
LogicalResult
IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
  LoadOpAdaptor loadOperands(operands);
  auto loc = loadOp.getLoc();
  auto memrefType = loadOp.memref().getType().cast<MemRefType>();
  // Non-signless-integer element types are handled by LoadOpPattern.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
                           loadOperands.indices(), loc, rewriter);

  // srcBits: element bitwidth before conversion; dstBits: bitwidth of the
  // array element in the converted SPIR-V storage type. The converted memref
  // is a pointer to a struct wrapping a runtime array.
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto dstType = typeConverter.convertType(memrefType)
                     .cast<spirv::PointerType>()
                     .getPointeeType()
                     .cast<spirv::StructType>()
                     .getElementType(0)
                     .cast<spirv::ArrayType>()
                     .getElementType();
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // If the rewritten load op has the same bit width, use the loaded value
  // directly; no emulation is needed.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
                                               accessChainOp.getResult());
    return success();
  }

  // Assume that getElementPtr() works linearly. If it's a scalar, the method
  // still returns a linearized access. If the access is not linearized,
  // there will be offset issues.
  assert(accessChainOp.indices().size() == 2);
  // Retarget the access chain at the wider dstBits element containing the
  // requested srcBits element, then load that whole word (propagating any
  // memory-access/alignment attributes from the original op).
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Value spvLoadOp = rewriter.create<spirv::LoadOp>(
      loc, dstType, adjustedPtr,
      loadOp->getAttrOfType<IntegerAttr>(
          spirv::attributeName<spirv::MemoryAccess>()),
      loadOp->getAttrOfType<IntegerAttr>("alignment"));

  // Shift the bits to the rightmost.
  // ____XXXX________ -> ____________XXXX
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
  Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
      loc, spvLoadOp.getType(), spvLoadOp, offset);

  // Apply the mask to extract the corresponding bits.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);

  // Apply sign extension on the loaded value unconditionally (shift left so
  // the source sign bit becomes the word's sign bit, then arithmetic shift
  // right back). The signedness semantic is carried in the operator itself;
  // we rely on other patterns to handle the casting.
  IntegerAttr shiftValueAttr =
      rewriter.getIntegerAttr(dstType, dstBits - srcBits);
  Value shiftValue =
      rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
  result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
                                                      shiftValue);
  result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
                                                          shiftValue);
  rewriter.replaceOp(loadOp, result);

  // The original access chain is no longer used after re-targeting; clean it
  // up explicitly.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
959
960 LogicalResult
matchAndRewrite(LoadOp loadOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const961 LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
962 ConversionPatternRewriter &rewriter) const {
963 LoadOpAdaptor loadOperands(operands);
964 auto memrefType = loadOp.memref().getType().cast<MemRefType>();
965 if (memrefType.getElementType().isSignlessInteger())
966 return failure();
967 auto loadPtr = spirv::getElementPtr(
968 *getTypeConverter<SPIRVTypeConverter>(), memrefType,
969 loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
970 rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
971 return success();
972 }
973
974 //===----------------------------------------------------------------------===//
975 // ReturnOp
976 //===----------------------------------------------------------------------===//
977
978 LogicalResult
matchAndRewrite(ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const979 ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
980 ConversionPatternRewriter &rewriter) const {
981 if (returnOp.getNumOperands() > 1)
982 return failure();
983
984 if (returnOp.getNumOperands() == 1) {
985 rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
986 } else {
987 rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
988 }
989 return success();
990 }
991
992 //===----------------------------------------------------------------------===//
993 // SelectOp
994 //===----------------------------------------------------------------------===//
995
996 LogicalResult
matchAndRewrite(SelectOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const997 SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
998 ConversionPatternRewriter &rewriter) const {
999 SelectOpAdaptor selectOperands(operands);
1000 rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
1001 selectOperands.true_value(),
1002 selectOperands.false_value());
1003 return success();
1004 }
1005
1006 //===----------------------------------------------------------------------===//
1007 // StoreOp
1008 //===----------------------------------------------------------------------===//
1009
/// Converts std.store on a memref with a signless-integer element type. When
/// the source bitwidth is narrower than the converted (storage) bitwidth,
/// emulates the store with a pair of atomic read-modify-write operations so
/// that neighboring elements packed in the same word are not clobbered by
/// concurrent threads.
LogicalResult
IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                                   ConversionPatternRewriter &rewriter) const {
  StoreOpAdaptor storeOperands(operands);
  auto memrefType = storeOp.memref().getType().cast<MemRefType>();
  // Non-signless-integer element types are handled by StoreOpPattern.
  if (!memrefType.getElementType().isSignlessInteger())
    return failure();

  auto loc = storeOp.getLoc();
  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
  spirv::AccessChainOp accessChainOp =
      spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
                           storeOperands.indices(), loc, rewriter);
  // srcBits: element bitwidth before conversion; dstBits: bitwidth of the
  // array element in the converted SPIR-V storage type (pointer to a struct
  // wrapping a runtime array).
  int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
  auto dstType = typeConverter.convertType(memrefType)
                     .cast<spirv::PointerType>()
                     .getPointeeType()
                     .cast<spirv::StructType>()
                     .getElementType(0)
                     .cast<spirv::ArrayType>()
                     .getElementType();
  int dstBits = dstType.getIntOrFloatBitWidth();
  assert(dstBits % srcBits == 0);

  // Same bitwidth: a plain store suffices, no emulation needed.
  if (srcBits == dstBits) {
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        storeOp, accessChainOp.getResult(), storeOperands.value());
    return success();
  }

  // Since there are multiple threads in the processing, the emulation is done
  // with atomic operations. E.g., if the stored value is i8, rewrite the
  // StoreOp to
  // 1) load a 32-bit integer
  // 2) clear 8 bits in the loaded value
  // 3) store the 32-bit value back
  // 4) load a 32-bit integer
  // 5) modify 8 bits in the loaded value
  // 6) store the 32-bit value back
  // Steps 1 to 3 are done by AtomicAnd as one atomic step, and steps 4 to 6
  // are done by AtomicOr as another atomic step.
  assert(accessChainOp.indices().size() == 2);
  Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
  Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);

  // Create a mask to clear the destination. E.g., if it is the second i8 in
  // i32, 0xFFFF00FF is created.
  Value mask = rewriter.create<spirv::ConstantOp>(
      loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
  Value clearBitsMask =
      rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
  clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);

  // Position the stored bits at the element's offset within the word.
  Value storeVal =
      shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
  Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
                                                   srcBits, dstBits, rewriter);
  Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
  if (!scope)
    return failure();
  // AtomicAnd clears the target bits; AtomicOr merges in the new value.
  Value result = rewriter.create<spirv::AtomicAndOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      clearBitsMask);
  result = rewriter.create<spirv::AtomicOrOp>(
      loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
      storeVal);

  // The AtomicOrOp has no side effect. Since it is already inserted, we can
  // just remove the original StoreOp. Note that rewriter.replaceOp()
  // doesn't work because it only accepts that the numbers of result are the
  // same.
  rewriter.eraseOp(storeOp);

  // The original access chain is no longer used after re-targeting; clean it
  // up explicitly.
  assert(accessChainOp.use_empty());
  rewriter.eraseOp(accessChainOp);

  return success();
}
1088
1089 LogicalResult
matchAndRewrite(StoreOp storeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1090 StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
1091 ConversionPatternRewriter &rewriter) const {
1092 StoreOpAdaptor storeOperands(operands);
1093 auto memrefType = storeOp.memref().getType().cast<MemRefType>();
1094 if (memrefType.getElementType().isSignlessInteger())
1095 return failure();
1096 auto storePtr =
1097 spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
1098 storeOperands.memref(), storeOperands.indices(),
1099 storeOp.getLoc(), rewriter);
1100 rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
1101 storeOperands.value());
1102 return success();
1103 }
1104
1105 //===----------------------------------------------------------------------===//
1106 // XorOp
1107 //===----------------------------------------------------------------------===//
1108
1109 LogicalResult
matchAndRewrite(XOrOp xorOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1110 XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
1111 ConversionPatternRewriter &rewriter) const {
1112 assert(operands.size() == 2);
1113
1114 if (isBoolScalarOrVector(operands.front().getType()))
1115 return failure();
1116
1117 auto dstType = getTypeConverter()->convertType(xorOp.getType());
1118 if (!dstType)
1119 return failure();
1120 rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
1121
1122 return success();
1123 }
1124
1125 //===----------------------------------------------------------------------===//
1126 // Pattern population
1127 //===----------------------------------------------------------------------===//
1128
namespace mlir {
/// Appends to `patterns` all conversion patterns that lower standard dialect
/// ops to the SPIR-V dialect, using `typeConverter` to convert types.
void populateStandardToSPIRVPatterns(MLIRContext *context,
                                     SPIRVTypeConverter &typeConverter,
                                     OwningRewritePatternList &patterns) {
  patterns.insert<
      // Unary and binary patterns
      BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
      BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
      UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
      UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
      UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
      UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
      UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
      UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
      UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
      UnaryAndBinaryOpPattern<FloorFOp, spirv::GLSLFloorOp>,
      UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
      UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
      UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
      UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
      UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
      UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
      UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
      UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
      UnaryAndBinaryOpPattern<SignedShiftRightOp,
                              spirv::ShiftRightArithmeticOp>,
      UnaryAndBinaryOpPattern<SinOp, spirv::GLSLSinOp>,
      UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
      UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
      UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
      UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
      UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
      UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
      UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
      SignedRemIOpPattern, XOrOpPattern,

      // Comparison patterns
      BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,

      // Constant patterns
      ConstantCompositeOpPattern, ConstantScalarOpPattern,

      // Memory patterns
      AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
      LoadOpPattern, StoreOpPattern,

      ReturnOpPattern, SelectOpPattern,

      // Type cast patterns
      UIToFPI1Pattern, ZeroExtendI1Pattern,
      TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
      TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
      TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,
      TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
      TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
      TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
      TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
      TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter,
                                                          context);

  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
  // capability is available.
  patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
                                          /*benefit=*/2);
}
} // namespace mlir
1195