1 //===- StandardToSPIRV.cpp - Standard to SPIR-V Patterns ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert standard dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "std-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Returns true if the given `type` is a boolean scalar or vector type.
isBoolScalarOrVector(Type type)32 static bool isBoolScalarOrVector(Type type) {
33   if (type.isInteger(1))
34     return true;
35   if (auto vecType = type.dyn_cast<VectorType>())
36     return vecType.getElementType().isInteger(1);
37   return false;
38 }
39 
40 /// Converts the given `srcAttr` into a boolean attribute if it holds an
41 /// integral value. Returns null attribute if conversion fails.
convertBoolAttr(Attribute srcAttr,Builder builder)42 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
43   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
44     return boolAttr;
45   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
46     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
47   return BoolAttr();
48 }
49 
50 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
51 /// Returns null attribute if conversion fails.
convertIntegerAttr(IntegerAttr srcAttr,IntegerType dstType,Builder builder)52 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
53                                       Builder builder) {
54   // If the source number uses less active bits than the target bitwidth, then
55   // it should be safe to convert.
56   if (srcAttr.getValue().isIntN(dstType.getWidth()))
57     return builder.getIntegerAttr(dstType, srcAttr.getInt());
58 
59   // XXX: Try again by interpreting the source number as a signed value.
60   // Although integers in the standard dialect are signless, they can represent
61   // a signed number. It's the operation decides how to interpret. This is
62   // dangerous, but it seems there is no good way of handling this if we still
63   // want to change the bitwidth. Emit a message at least.
64   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
65     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
66     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
67                             << dstAttr << "' for type '" << dstType << "'\n");
68     return dstAttr;
69   }
70 
71   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
72                           << "' illegal: cannot fit into target type '"
73                           << dstType << "'\n");
74   return IntegerAttr();
75 }
76 
77 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
78 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
convertFloatAttr(FloatAttr srcAttr,FloatType dstType,Builder builder)79 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
80                                   Builder builder) {
81   // Only support converting to float for now.
82   if (!dstType.isF32())
83     return FloatAttr();
84 
85   // Try to convert the source floating-point number to single precision.
86   APFloat dstVal = srcAttr.getValue();
87   bool losesInfo = false;
88   APFloat::opStatus status =
89       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
90   if (status != APFloat::opOK || losesInfo) {
91     LLVM_DEBUG(llvm::dbgs()
92                << srcAttr << " illegal: cannot fit into converted type '"
93                << dstType << "'\n");
94     return FloatAttr();
95   }
96 
97   return builder.getF32FloatAttr(dstVal.convertToFloat());
98 }
99 
100 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
101 /// the sign of `signOperand`.
102 ///
103 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
104 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
105 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
106 /// if either operand can be negative. Emulate it via spv.UMod.
emulateSignedRemainder(Location loc,Value lhs,Value rhs,Value signOperand,OpBuilder & builder)107 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
108                                     Value signOperand, OpBuilder &builder) {
109   assert(lhs.getType() == rhs.getType());
110   assert(lhs == signOperand || rhs == signOperand);
111 
112   Type type = lhs.getType();
113 
114   // Calculate the remainder with spv.UMod.
115   Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
116   Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
117   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
118 
119   // Fix the sign.
120   Value isPositive;
121   if (lhs == signOperand)
122     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
123   else
124     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
125   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
126   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
127 }
128 
129 /// Returns the offset of the value in `targetBits` representation.
130 ///
131 /// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
132 /// It's assumed to be non-negative.
133 ///
134 /// When accessing an element in the array treating as having elements of
135 /// `targetBits`, multiple values are loaded in the same time. The method
136 /// returns the offset where the `srcIdx` locates in the value. For example, if
137 /// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
138 /// located at (x % 4) * 8. Because there are four elements in one i32, and one
139 /// element has 8 bits.
getOffsetForBitwidth(Location loc,Value srcIdx,int sourceBits,int targetBits,OpBuilder & builder)140 static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
141                                   int targetBits, OpBuilder &builder) {
142   assert(targetBits % sourceBits == 0);
143   IntegerType targetType = builder.getIntegerType(targetBits);
144   IntegerAttr idxAttr =
145       builder.getIntegerAttr(targetType, targetBits / sourceBits);
146   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, idxAttr);
147   IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
148   auto srcBitsValue =
149       builder.create<spirv::ConstantOp>(loc, targetType, srcBitsAttr);
150   auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
151   return builder.create<spirv::IMulOp>(loc, targetType, m, srcBitsValue);
152 }
153 
154 /// Returns an adjusted spirv::AccessChainOp. Based on the
155 /// extension/capabilities, certain integer bitwidths `sourceBits` might not be
156 /// supported. During conversion if a memref of an unsupported type is used,
157 /// load/stores to this memref need to be modified to use a supported higher
158 /// bitwidth `targetBits` and extracting the required bits. For an accessing a
159 /// 1D array (spv.array or spv.rt_array), the last index is modified to load the
160 /// bits needed. The extraction of the actual bits needed are handled
161 /// separately. Note that this only works for a 1-D tensor.
adjustAccessChainForBitwidth(SPIRVTypeConverter & typeConverter,spirv::AccessChainOp op,int sourceBits,int targetBits,OpBuilder & builder)162 static Value adjustAccessChainForBitwidth(SPIRVTypeConverter &typeConverter,
163                                           spirv::AccessChainOp op,
164                                           int sourceBits, int targetBits,
165                                           OpBuilder &builder) {
166   assert(targetBits % sourceBits == 0);
167   const auto loc = op.getLoc();
168   IntegerType targetType = builder.getIntegerType(targetBits);
169   IntegerAttr attr =
170       builder.getIntegerAttr(targetType, targetBits / sourceBits);
171   auto idx = builder.create<spirv::ConstantOp>(loc, targetType, attr);
172   auto lastDim = op->getOperand(op.getNumOperands() - 1);
173   auto indices = llvm::to_vector<4>(op.indices());
174   // There are two elements if this is a 1-D tensor.
175   assert(indices.size() == 2);
176   indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
177   Type t = typeConverter.convertType(op.component_ptr().getType());
178   return builder.create<spirv::AccessChainOp>(loc, t, op.base_ptr(), indices);
179 }
180 
181 /// Returns the shifted `targetBits`-bit value with the given offset.
shiftValue(Location loc,Value value,Value offset,Value mask,int targetBits,OpBuilder & builder)182 static Value shiftValue(Location loc, Value value, Value offset, Value mask,
183                         int targetBits, OpBuilder &builder) {
184   Type targetType = builder.getIntegerType(targetBits);
185   Value result = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
186   return builder.create<spirv::ShiftLeftLogicalOp>(loc, targetType, result,
187                                                    offset);
188 }
189 
190 /// Returns true if the allocations of type `t` can be lowered to SPIR-V.
isAllocationSupported(MemRefType t)191 static bool isAllocationSupported(MemRefType t) {
192   // Currently only support workgroup local memory allocations with static
193   // shape and int or float or vector of int or float element type.
194   if (!(t.hasStaticShape() &&
195         SPIRVTypeConverter::getMemorySpaceForStorageClass(
196             spirv::StorageClass::Workgroup) == t.getMemorySpace()))
197     return false;
198   Type elementType = t.getElementType();
199   if (auto vecType = elementType.dyn_cast<VectorType>())
200     elementType = vecType.getElementType();
201   return elementType.isIntOrFloat();
202 }
203 
204 /// Returns the scope to use for atomic operations use for emulating store
205 /// operations of unsupported integer bitwidths, based on the memref
206 /// type. Returns None on failure.
getAtomicOpScope(MemRefType t)207 static Optional<spirv::Scope> getAtomicOpScope(MemRefType t) {
208   Optional<spirv::StorageClass> storageClass =
209       SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace());
210   if (!storageClass)
211     return {};
212   switch (*storageClass) {
213   case spirv::StorageClass::StorageBuffer:
214     return spirv::Scope::Device;
215   case spirv::StorageClass::Workgroup:
216     return spirv::Scope::Workgroup;
217   default: {
218   }
219   }
220   return {};
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // Operation conversion
225 //===----------------------------------------------------------------------===//
226 
227 // Note that DRR cannot be used for the patterns in this file: we may need to
228 // convert type along the way, which requires ConversionPattern. DRR generates
229 // normal RewritePattern.
230 
231 namespace {
232 
233 /// Converts an allocation operation to SPIR-V. Currently only supports lowering
234 /// to Workgroup memory when the size is constant.  Note that this pattern needs
235 /// to be applied in a pass that runs at least at spv.module scope since it wil
236 /// ladd global variables into the spv.module.
237 class AllocOpPattern final : public OpConversionPattern<AllocOp> {
238 public:
239   using OpConversionPattern<AllocOp>::OpConversionPattern;
240 
241   LogicalResult
matchAndRewrite(AllocOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const242   matchAndRewrite(AllocOp operation, ArrayRef<Value> operands,
243                   ConversionPatternRewriter &rewriter) const override {
244     MemRefType allocType = operation.getType();
245     if (!isAllocationSupported(allocType))
246       return operation.emitError("unhandled allocation type");
247 
248     // Get the SPIR-V type for the allocation.
249     Type spirvType = getTypeConverter()->convertType(allocType);
250 
251     // Insert spv.globalVariable for this allocation.
252     Operation *parent =
253         SymbolTable::getNearestSymbolTable(operation->getParentOp());
254     if (!parent)
255       return failure();
256     Location loc = operation.getLoc();
257     spirv::GlobalVariableOp varOp;
258     {
259       OpBuilder::InsertionGuard guard(rewriter);
260       Block &entryBlock = *parent->getRegion(0).begin();
261       rewriter.setInsertionPointToStart(&entryBlock);
262       auto varOps = entryBlock.getOps<spirv::GlobalVariableOp>();
263       std::string varName =
264           std::string("__workgroup_mem__") +
265           std::to_string(std::distance(varOps.begin(), varOps.end()));
266       varOp = rewriter.create<spirv::GlobalVariableOp>(
267           loc, TypeAttr::get(spirvType), varName,
268           /*initializer = */ nullptr);
269     }
270 
271     // Get pointer to global variable at the current scope.
272     rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(operation, varOp);
273     return success();
274   }
275 };
276 
277 /// Removed a deallocation if it is a supported allocation. Currently only
278 /// removes deallocation if the memory space is workgroup memory.
279 class DeallocOpPattern final : public OpConversionPattern<DeallocOp> {
280 public:
281   using OpConversionPattern<DeallocOp>::OpConversionPattern;
282 
283   LogicalResult
matchAndRewrite(DeallocOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const284   matchAndRewrite(DeallocOp operation, ArrayRef<Value> operands,
285                   ConversionPatternRewriter &rewriter) const override {
286     MemRefType deallocType = operation.memref().getType().cast<MemRefType>();
287     if (!isAllocationSupported(deallocType))
288       return operation.emitError("unhandled deallocation type");
289     rewriter.eraseOp(operation);
290     return success();
291   }
292 };
293 
294 /// Converts unary and binary standard operations to SPIR-V operations.
295 template <typename StdOp, typename SPIRVOp>
296 class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
297 public:
298   using OpConversionPattern<StdOp>::OpConversionPattern;
299 
300   LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const301   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
302                   ConversionPatternRewriter &rewriter) const override {
303     assert(operands.size() <= 2);
304     auto dstType = this->getTypeConverter()->convertType(operation.getType());
305     if (!dstType)
306       return failure();
307     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
308         dstType != operation.getType()) {
309       return operation.emitError(
310           "bitwidth emulation is not implemented yet on unsigned op");
311     }
312     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
313     return success();
314   }
315 };
316 
/// Converts std.remi_signed to SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to
/// Vulkan restrictions over spv.SRem and spv.SMod. The remainder is instead
/// emulated via emulateSignedRemainder (spv.UMod on absolute values plus a
/// sign fix-up), with the result taking the sign of the first operand.
class SignedRemIOpPattern final : public OpConversionPattern<SignedRemIOp> {
public:
  using OpConversionPattern<SignedRemIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SignedRemIOp remOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
329 
330 /// Converts bitwise standard operations to SPIR-V operations. This is a special
331 /// pattern other than the BinaryOpPatternPattern because if the operands are
332 /// boolean values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
333 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
334 template <typename StdOp, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
335 class BitwiseOpPattern final : public OpConversionPattern<StdOp> {
336 public:
337   using OpConversionPattern<StdOp>::OpConversionPattern;
338 
339   LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const340   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
341                   ConversionPatternRewriter &rewriter) const override {
342     assert(operands.size() == 2);
343     auto dstType =
344         this->getTypeConverter()->convertType(operation.getResult().getType());
345     if (!dstType)
346       return failure();
347     if (isBoolScalarOrVector(operands.front().getType())) {
348       rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(operation, dstType,
349                                                            operands);
350     } else {
351       rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(operation, dstType,
352                                                            operands);
353     }
354     return success();
355   }
356 };
357 
/// Converts composite std.constant operation to spv.constant.
///
/// Handles vector and ranked tensor constants backed by a DenseElementsAttr.
/// Tensors of rank > 1 are linearized to 1-D; vectors of rank > 1 are not
/// supported yet. Element values are converted when the target element type
/// differs from the source.
class ConstantCompositeOpPattern final
    : public OpConversionPattern<ConstantOp> {
public:
  using OpConversionPattern<ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
368 
/// Converts scalar std.constant operation to spv.constant.
///
/// Handles integer, index, and floating-point scalars. i1 constants written as
/// 0/1 integers are turned into boolean attributes. Other values are rebuilt
/// for the converted type and the match fails if they do not fit.
class ConstantScalarOpPattern final : public OpConversionPattern<ConstantOp> {
public:
  using OpConversionPattern<ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ConstantOp constOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
378 
/// Converts floating-point comparison operations to SPIR-V ops.
///
/// The definition dispatches on `cmpFOp.getPredicate()` to pick the SPIR-V
/// comparison op.
class CmpFOpPattern final : public OpConversionPattern<CmpFOp> {
public:
  using OpConversionPattern<CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
388 
/// Converts floating point NaN check to SPIR-V ops. This pattern requires
/// Kernel capability.
///
/// NOTE(review): presumably covers the ordered/unordered CmpF predicates and
/// is the Kernel-only alternative to CmpFOpNanNonePattern; confirm against the
/// definition.
class CmpFOpNanKernelPattern final : public OpConversionPattern<CmpFOp> {
public:
  using OpConversionPattern<CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
399 
/// Converts floating point NaN check to SPIR-V ops. This pattern does not
/// require additional capability.
///
/// NOTE(review): presumably the capability-free alternative to
/// CmpFOpNanKernelPattern for the same predicates; confirm against the
/// definition.
class CmpFOpNanNonePattern final : public OpConversionPattern<CmpFOp> {
public:
  using OpConversionPattern<CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
410 
/// Converts integer compare operation on i1 type operands to SPIR-V ops.
///
/// NOTE(review): presumably the boolean-operand counterpart of CmpIOpPattern,
/// which would then handle the remaining integer types; confirm against the
/// definitions.
class BoolCmpIOpPattern final : public OpConversionPattern<CmpIOp> {
public:
  using OpConversionPattern<CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
420 
/// Converts integer compare operation to SPIR-V ops.
///
/// NOTE(review): i1 operands are presumably left to BoolCmpIOpPattern; confirm
/// against the definition.
class CmpIOpPattern final : public OpConversionPattern<CmpIOp> {
public:
  using OpConversionPattern<CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
430 
/// Converts std.load to spv.Load on integers.
///
/// NOTE(review): given its name and the bitwidth helpers in this file
/// (adjustAccessChainForBitwidth, getOffsetForBitwidth, shiftValue), this
/// presumably emulates loads of integer bitwidths unsupported by the target
/// via a wider type; confirm against the definition.
class IntLoadOpPattern final : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern<LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
440 
/// Converts std.load to spv.Load.
///
/// NOTE(review): presumably handles the loads not covered by
/// IntLoadOpPattern; confirm against the definition.
class LoadOpPattern final : public OpConversionPattern<LoadOp> {
public:
  using OpConversionPattern<LoadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
450 
/// Converts std.return to spv.Return.
///
/// The matchAndRewrite definition is out of line.
class ReturnOpPattern final : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
460 
/// Converts std.select to spv.Select.
///
/// The matchAndRewrite definition is out of line.
class SelectOpPattern final : public OpConversionPattern<SelectOp> {
public:
  using OpConversionPattern<SelectOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
469 
/// Converts std.store to spv.Store on integers.
///
/// NOTE(review): stores of integer bitwidths unsupported by the target are
/// presumably emulated with atomic ops whose scope comes from
/// getAtomicOpScope; confirm against the definition.
class IntStoreOpPattern final : public OpConversionPattern<StoreOp> {
public:
  using OpConversionPattern<StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
479 
/// Converts std.store to spv.Store.
///
/// NOTE(review): presumably handles the stores not covered by
/// IntStoreOpPattern; confirm against the definition.
class StoreOpPattern final : public OpConversionPattern<StoreOp> {
public:
  using OpConversionPattern<StoreOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
489 
490 /// Converts std.zexti to spv.Select if the type of source is i1 or vector of
491 /// i1.
492 class ZeroExtendI1Pattern final : public OpConversionPattern<ZeroExtendIOp> {
493 public:
494   using OpConversionPattern<ZeroExtendIOp>::OpConversionPattern;
495 
496   LogicalResult
matchAndRewrite(ZeroExtendIOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const497   matchAndRewrite(ZeroExtendIOp op, ArrayRef<Value> operands,
498                   ConversionPatternRewriter &rewriter) const override {
499     auto srcType = operands.front().getType();
500     if (!isBoolScalarOrVector(srcType))
501       return failure();
502 
503     auto dstType =
504         this->getTypeConverter()->convertType(op.getResult().getType());
505     Location loc = op.getLoc();
506     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
507     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
508     rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
509         op, dstType, operands.front(), one, zero);
510     return success();
511   }
512 };
513 
514 /// Converts std.uitofp to spv.Select if the type of source is i1 or vector of
515 /// i1.
516 class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
517 public:
518   using OpConversionPattern<UIToFPOp>::OpConversionPattern;
519 
520   LogicalResult
matchAndRewrite(UIToFPOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const521   matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
522                   ConversionPatternRewriter &rewriter) const override {
523     auto srcType = operands.front().getType();
524     if (!isBoolScalarOrVector(srcType))
525       return failure();
526 
527     auto dstType =
528         this->getTypeConverter()->convertType(op.getResult().getType());
529     Location loc = op.getLoc();
530     Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
531     Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
532     rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
533         op, dstType, operands.front(), one, zero);
534     return success();
535   }
536 };
537 
538 /// Converts type-casting standard operations to SPIR-V operations.
539 template <typename StdOp, typename SPIRVOp>
540 class TypeCastingOpPattern final : public OpConversionPattern<StdOp> {
541 public:
542   using OpConversionPattern<StdOp>::OpConversionPattern;
543 
544   LogicalResult
matchAndRewrite(StdOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const545   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
546                   ConversionPatternRewriter &rewriter) const override {
547     assert(operands.size() == 1);
548     auto srcType = operands.front().getType();
549     if (isBoolScalarOrVector(srcType))
550       return failure();
551     auto dstType =
552         this->getTypeConverter()->convertType(operation.getResult().getType());
553     if (dstType == srcType) {
554       // Due to type conversion, we are seeing the same source and target type.
555       // Then we can just erase this operation by forwarding its operand.
556       rewriter.replaceOp(operation, operands.front());
557     } else {
558       rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
559                                                     operands);
560     }
561     return success();
562   }
563 };
564 
/// Converts std.xor to SPIR-V operations.
///
/// NOTE(review): kept separate from BitwiseOpPattern, presumably because the
/// boolean case needs different handling; confirm against the definition.
class XOrOpPattern final : public OpConversionPattern<XOrOp> {
public:
  using OpConversionPattern<XOrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
574 
575 } // namespace
576 
577 //===----------------------------------------------------------------------===//
578 // SignedRemIOpPattern
579 //===----------------------------------------------------------------------===//
580 
matchAndRewrite(SignedRemIOp remOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const581 LogicalResult SignedRemIOpPattern::matchAndRewrite(
582     SignedRemIOp remOp, ArrayRef<Value> operands,
583     ConversionPatternRewriter &rewriter) const {
584   Value result = emulateSignedRemainder(remOp.getLoc(), operands[0],
585                                         operands[1], operands[0], rewriter);
586   rewriter.replaceOp(remOp, result);
587 
588   return success();
589 }
590 
591 //===----------------------------------------------------------------------===//
592 // ConstantOp with composite type.
593 //===----------------------------------------------------------------------===//
594 
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const595 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
596     ConstantOp constOp, ArrayRef<Value> operands,
597     ConversionPatternRewriter &rewriter) const {
598   auto srcType = constOp.getType().dyn_cast<ShapedType>();
599   if (!srcType)
600     return failure();
601 
602   // std.constant should only have vector or tenor types.
603   assert((srcType.isa<VectorType, RankedTensorType>()));
604 
605   auto dstType = getTypeConverter()->convertType(srcType);
606   if (!dstType)
607     return failure();
608 
609   auto dstElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
610   ShapedType dstAttrType = dstElementsAttr.getType();
611   if (!dstElementsAttr)
612     return failure();
613 
614   // If the composite type has more than one dimensions, perform linearization.
615   if (srcType.getRank() > 1) {
616     if (srcType.isa<RankedTensorType>()) {
617       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
618                                           srcType.getElementType());
619       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
620     } else {
621       // TODO: add support for large vectors.
622       return failure();
623     }
624   }
625 
626   Type srcElemType = srcType.getElementType();
627   Type dstElemType;
628   // Tensor types are converted to SPIR-V array types; vector types are
629   // converted to SPIR-V vector/array types.
630   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
631     dstElemType = arrayType.getElementType();
632   else
633     dstElemType = dstType.cast<VectorType>().getElementType();
634 
635   // If the source and destination element types are different, perform
636   // attribute conversion.
637   if (srcElemType != dstElemType) {
638     SmallVector<Attribute, 8> elements;
639     if (srcElemType.isa<FloatType>()) {
640       for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
641         FloatAttr dstAttr = convertFloatAttr(
642             srcAttr.cast<FloatAttr>(), dstElemType.cast<FloatType>(), rewriter);
643         if (!dstAttr)
644           return failure();
645         elements.push_back(dstAttr);
646       }
647     } else if (srcElemType.isInteger(1)) {
648       return failure();
649     } else {
650       for (Attribute srcAttr : dstElementsAttr.getAttributeValues()) {
651         IntegerAttr dstAttr =
652             convertIntegerAttr(srcAttr.cast<IntegerAttr>(),
653                                dstElemType.cast<IntegerType>(), rewriter);
654         if (!dstAttr)
655           return failure();
656         elements.push_back(dstAttr);
657       }
658     }
659 
660     // Unfortunately, we cannot use dialect-specific types for element
661     // attributes; element attributes only works with builtin types. So we need
662     // to prepare another converted builtin types for the destination elements
663     // attribute.
664     if (dstAttrType.isa<RankedTensorType>())
665       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
666     else
667       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
668 
669     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
670   }
671 
672   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
673                                                  dstElementsAttr);
674   return success();
675 }
676 
677 //===----------------------------------------------------------------------===//
678 // ConstantOp with scalar type.
679 //===----------------------------------------------------------------------===//
680 
matchAndRewrite(ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const681 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
682     ConstantOp constOp, ArrayRef<Value> operands,
683     ConversionPatternRewriter &rewriter) const {
684   Type srcType = constOp.getType();
685   if (!srcType.isIntOrIndexOrFloat())
686     return failure();
687 
688   Type dstType = getTypeConverter()->convertType(srcType);
689   if (!dstType)
690     return failure();
691 
692   // Floating-point types.
693   if (srcType.isa<FloatType>()) {
694     auto srcAttr = constOp.value().cast<FloatAttr>();
695     auto dstAttr = srcAttr;
696 
697     // Floating-point types not supported in the target environment are all
698     // converted to float type.
699     if (srcType != dstType) {
700       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
701       if (!dstAttr)
702         return failure();
703     }
704 
705     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
706     return success();
707   }
708 
709   // Bool type.
710   if (srcType.isInteger(1)) {
711     // std.constant can use 0/1 instead of true/false for i1 values. We need to
712     // handle that here.
713     auto dstAttr = convertBoolAttr(constOp.value(), rewriter);
714     if (!dstAttr)
715       return failure();
716     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
717     return success();
718   }
719 
720   // IndexType or IntegerType. Index values are converted to 32-bit integer
721   // values when converting to SPIR-V.
722   auto srcAttr = constOp.value().cast<IntegerAttr>();
723   auto dstAttr =
724       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
725   if (!dstAttr)
726     return failure();
727   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
728   return success();
729 }
730 
731 //===----------------------------------------------------------------------===//
732 // CmpFOp
733 //===----------------------------------------------------------------------===//
734 
735 LogicalResult
matchAndRewrite(CmpFOp cmpFOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const736 CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef<Value> operands,
737                                ConversionPatternRewriter &rewriter) const {
738   CmpFOpAdaptor cmpFOpOperands(operands);
739 
740   switch (cmpFOp.getPredicate()) {
741 #define DISPATCH(cmpPredicate, spirvOp)                                        \
742   case cmpPredicate:                                                           \
743     rewriter.replaceOpWithNewOp<spirvOp>(cmpFOp, cmpFOp.getResult().getType(), \
744                                          cmpFOpOperands.lhs(),                 \
745                                          cmpFOpOperands.rhs());                \
746     return success();
747 
748     // Ordered.
749     DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp);
750     DISPATCH(CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
751     DISPATCH(CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
752     DISPATCH(CmpFPredicate::OLT, spirv::FOrdLessThanOp);
753     DISPATCH(CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
754     DISPATCH(CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
755     // Unordered.
756     DISPATCH(CmpFPredicate::UEQ, spirv::FUnordEqualOp);
757     DISPATCH(CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
758     DISPATCH(CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
759     DISPATCH(CmpFPredicate::ULT, spirv::FUnordLessThanOp);
760     DISPATCH(CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
761     DISPATCH(CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
762 
763 #undef DISPATCH
764 
765   default:
766     break;
767   }
768   return failure();
769 }
770 
matchAndRewrite(CmpFOp cmpFOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const771 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
772     CmpFOp cmpFOp, ArrayRef<Value> operands,
773     ConversionPatternRewriter &rewriter) const {
774   CmpFOpAdaptor cmpFOpOperands(operands);
775 
776   if (cmpFOp.getPredicate() == CmpFPredicate::ORD) {
777     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(cmpFOp, cmpFOpOperands.lhs(),
778                                                   cmpFOpOperands.rhs());
779     return success();
780   }
781 
782   if (cmpFOp.getPredicate() == CmpFPredicate::UNO) {
783     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(
784         cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs());
785     return success();
786   }
787 
788   return failure();
789 }
790 
matchAndRewrite(CmpFOp cmpFOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const791 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
792     CmpFOp cmpFOp, ArrayRef<Value> operands,
793     ConversionPatternRewriter &rewriter) const {
794   if (cmpFOp.getPredicate() != CmpFPredicate::ORD &&
795       cmpFOp.getPredicate() != CmpFPredicate::UNO)
796     return failure();
797 
798   CmpFOpAdaptor cmpFOpOperands(operands);
799   Location loc = cmpFOp.getLoc();
800 
801   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.lhs());
802   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, cmpFOpOperands.rhs());
803 
804   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
805   if (cmpFOp.getPredicate() == CmpFPredicate::ORD)
806     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
807 
808   rewriter.replaceOp(cmpFOp, replace);
809   return success();
810 }
811 
812 //===----------------------------------------------------------------------===//
813 // CmpIOp
814 //===----------------------------------------------------------------------===//
815 
816 LogicalResult
matchAndRewrite(CmpIOp cmpIOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const817 BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
818                                    ConversionPatternRewriter &rewriter) const {
819   CmpIOpAdaptor cmpIOpOperands(operands);
820 
821   Type operandType = cmpIOp.lhs().getType();
822   if (!isBoolScalarOrVector(operandType))
823     return failure();
824 
825   switch (cmpIOp.getPredicate()) {
826 #define DISPATCH(cmpPredicate, spirvOp)                                        \
827   case cmpPredicate:                                                           \
828     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
829                                          cmpIOpOperands.lhs(),                 \
830                                          cmpIOpOperands.rhs());                \
831     return success();
832 
833     DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp);
834     DISPATCH(CmpIPredicate::ne, spirv::LogicalNotEqualOp);
835 
836 #undef DISPATCH
837   default:;
838   }
839   return failure();
840 }
841 
842 LogicalResult
matchAndRewrite(CmpIOp cmpIOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const843 CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef<Value> operands,
844                                ConversionPatternRewriter &rewriter) const {
845   CmpIOpAdaptor cmpIOpOperands(operands);
846 
847   Type operandType = cmpIOp.lhs().getType();
848   if (isBoolScalarOrVector(operandType))
849     return failure();
850 
851   switch (cmpIOp.getPredicate()) {
852 #define DISPATCH(cmpPredicate, spirvOp)                                        \
853   case cmpPredicate:                                                           \
854     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
855         operandType != this->getTypeConverter()->convertType(operandType)) {   \
856       return cmpIOp.emitError(                                                 \
857           "bitwidth emulation is not implemented yet on unsigned op");         \
858     }                                                                          \
859     rewriter.replaceOpWithNewOp<spirvOp>(cmpIOp, cmpIOp.getResult().getType(), \
860                                          cmpIOpOperands.lhs(),                 \
861                                          cmpIOpOperands.rhs());                \
862     return success();
863 
864     DISPATCH(CmpIPredicate::eq, spirv::IEqualOp);
865     DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp);
866     DISPATCH(CmpIPredicate::slt, spirv::SLessThanOp);
867     DISPATCH(CmpIPredicate::sle, spirv::SLessThanEqualOp);
868     DISPATCH(CmpIPredicate::sgt, spirv::SGreaterThanOp);
869     DISPATCH(CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
870     DISPATCH(CmpIPredicate::ult, spirv::ULessThanOp);
871     DISPATCH(CmpIPredicate::ule, spirv::ULessThanEqualOp);
872     DISPATCH(CmpIPredicate::ugt, spirv::UGreaterThanOp);
873     DISPATCH(CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
874 
875 #undef DISPATCH
876   }
877   return failure();
878 }
879 
880 //===----------------------------------------------------------------------===//
881 // LoadOp
882 //===----------------------------------------------------------------------===//
883 
884 LogicalResult
matchAndRewrite(LoadOp loadOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const885 IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
886                                   ConversionPatternRewriter &rewriter) const {
887   LoadOpAdaptor loadOperands(operands);
888   auto loc = loadOp.getLoc();
889   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
890   if (!memrefType.getElementType().isSignlessInteger())
891     return failure();
892 
893   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
894   spirv::AccessChainOp accessChainOp =
895       spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(),
896                            loadOperands.indices(), loc, rewriter);
897 
898   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
899   auto dstType = typeConverter.convertType(memrefType)
900                      .cast<spirv::PointerType>()
901                      .getPointeeType()
902                      .cast<spirv::StructType>()
903                      .getElementType(0)
904                      .cast<spirv::ArrayType>()
905                      .getElementType();
906   int dstBits = dstType.getIntOrFloatBitWidth();
907   assert(dstBits % srcBits == 0);
908 
909   // If the rewrited load op has the same bit width, use the loading value
910   // directly.
911   if (srcBits == dstBits) {
912     rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp,
913                                                accessChainOp.getResult());
914     return success();
915   }
916 
917   // Assume that getElementPtr() works linearizely. If it's a scalar, the method
918   // still returns a linearized accessing. If the accessing is not linearized,
919   // there will be offset issues.
920   assert(accessChainOp.indices().size() == 2);
921   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
922                                                    srcBits, dstBits, rewriter);
923   Value spvLoadOp = rewriter.create<spirv::LoadOp>(
924       loc, dstType, adjustedPtr,
925       loadOp->getAttrOfType<IntegerAttr>(
926           spirv::attributeName<spirv::MemoryAccess>()),
927       loadOp->getAttrOfType<IntegerAttr>("alignment"));
928 
929   // Shift the bits to the rightmost.
930   // ____XXXX________ -> ____________XXXX
931   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
932   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
933   Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
934       loc, spvLoadOp.getType(), spvLoadOp, offset);
935 
936   // Apply the mask to extract corresponding bits.
937   Value mask = rewriter.create<spirv::ConstantOp>(
938       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
939   result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
940 
941   // Apply sign extension on the loading value unconditionally. The signedness
942   // semantic is carried in the operator itself, we relies other pattern to
943   // handle the casting.
944   IntegerAttr shiftValueAttr =
945       rewriter.getIntegerAttr(dstType, dstBits - srcBits);
946   Value shiftValue =
947       rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
948   result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
949                                                       shiftValue);
950   result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
951                                                           shiftValue);
952   rewriter.replaceOp(loadOp, result);
953 
954   assert(accessChainOp.use_empty());
955   rewriter.eraseOp(accessChainOp);
956 
957   return success();
958 }
959 
960 LogicalResult
matchAndRewrite(LoadOp loadOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const961 LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef<Value> operands,
962                                ConversionPatternRewriter &rewriter) const {
963   LoadOpAdaptor loadOperands(operands);
964   auto memrefType = loadOp.memref().getType().cast<MemRefType>();
965   if (memrefType.getElementType().isSignlessInteger())
966     return failure();
967   auto loadPtr = spirv::getElementPtr(
968       *getTypeConverter<SPIRVTypeConverter>(), memrefType,
969       loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter);
970   rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, loadPtr);
971   return success();
972 }
973 
974 //===----------------------------------------------------------------------===//
975 // ReturnOp
976 //===----------------------------------------------------------------------===//
977 
978 LogicalResult
matchAndRewrite(ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const979 ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef<Value> operands,
980                                  ConversionPatternRewriter &rewriter) const {
981   if (returnOp.getNumOperands() > 1)
982     return failure();
983 
984   if (returnOp.getNumOperands() == 1) {
985     rewriter.replaceOpWithNewOp<spirv::ReturnValueOp>(returnOp, operands[0]);
986   } else {
987     rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
988   }
989   return success();
990 }
991 
992 //===----------------------------------------------------------------------===//
993 // SelectOp
994 //===----------------------------------------------------------------------===//
995 
996 LogicalResult
matchAndRewrite(SelectOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const997 SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
998                                  ConversionPatternRewriter &rewriter) const {
999   SelectOpAdaptor selectOperands(operands);
1000   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, selectOperands.condition(),
1001                                                selectOperands.true_value(),
1002                                                selectOperands.false_value());
1003   return success();
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // StoreOp
1008 //===----------------------------------------------------------------------===//
1009 
1010 LogicalResult
matchAndRewrite(StoreOp storeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1011 IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
1012                                    ConversionPatternRewriter &rewriter) const {
1013   StoreOpAdaptor storeOperands(operands);
1014   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
1015   if (!memrefType.getElementType().isSignlessInteger())
1016     return failure();
1017 
1018   auto loc = storeOp.getLoc();
1019   auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
1020   spirv::AccessChainOp accessChainOp =
1021       spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(),
1022                            storeOperands.indices(), loc, rewriter);
1023   int srcBits = memrefType.getElementType().getIntOrFloatBitWidth();
1024   auto dstType = typeConverter.convertType(memrefType)
1025                      .cast<spirv::PointerType>()
1026                      .getPointeeType()
1027                      .cast<spirv::StructType>()
1028                      .getElementType(0)
1029                      .cast<spirv::ArrayType>()
1030                      .getElementType();
1031   int dstBits = dstType.getIntOrFloatBitWidth();
1032   assert(dstBits % srcBits == 0);
1033 
1034   if (srcBits == dstBits) {
1035     rewriter.replaceOpWithNewOp<spirv::StoreOp>(
1036         storeOp, accessChainOp.getResult(), storeOperands.value());
1037     return success();
1038   }
1039 
1040   // Since there are multi threads in the processing, the emulation will be done
1041   // with atomic operations. E.g., if the storing value is i8, rewrite the
1042   // StoreOp to
1043   // 1) load a 32-bit integer
1044   // 2) clear 8 bits in the loading value
1045   // 3) store 32-bit value back
1046   // 4) load a 32-bit integer
1047   // 5) modify 8 bits in the loading value
1048   // 6) store 32-bit value back
1049   // The step 1 to step 3 are done by AtomicAnd as one atomic step, and the step
1050   // 4 to step 6 are done by AtomicOr as another atomic step.
1051   assert(accessChainOp.indices().size() == 2);
1052   Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
1053   Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
1054 
1055   // Create a mask to clear the destination. E.g., if it is the second i8 in
1056   // i32, 0xFFFF00FF is created.
1057   Value mask = rewriter.create<spirv::ConstantOp>(
1058       loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
1059   Value clearBitsMask =
1060       rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
1061   clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
1062 
1063   Value storeVal =
1064       shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter);
1065   Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
1066                                                    srcBits, dstBits, rewriter);
1067   Optional<spirv::Scope> scope = getAtomicOpScope(memrefType);
1068   if (!scope)
1069     return failure();
1070   Value result = rewriter.create<spirv::AtomicAndOp>(
1071       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
1072       clearBitsMask);
1073   result = rewriter.create<spirv::AtomicOrOp>(
1074       loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease,
1075       storeVal);
1076 
1077   // The AtomicOrOp has no side effect. Since it is already inserted, we can
1078   // just remove the original StoreOp. Note that rewriter.replaceOp()
1079   // doesn't work because it only accepts that the numbers of result are the
1080   // same.
1081   rewriter.eraseOp(storeOp);
1082 
1083   assert(accessChainOp.use_empty());
1084   rewriter.eraseOp(accessChainOp);
1085 
1086   return success();
1087 }
1088 
1089 LogicalResult
matchAndRewrite(StoreOp storeOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1090 StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef<Value> operands,
1091                                 ConversionPatternRewriter &rewriter) const {
1092   StoreOpAdaptor storeOperands(operands);
1093   auto memrefType = storeOp.memref().getType().cast<MemRefType>();
1094   if (memrefType.getElementType().isSignlessInteger())
1095     return failure();
1096   auto storePtr =
1097       spirv::getElementPtr(*getTypeConverter<SPIRVTypeConverter>(), memrefType,
1098                            storeOperands.memref(), storeOperands.indices(),
1099                            storeOp.getLoc(), rewriter);
1100   rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, storePtr,
1101                                               storeOperands.value());
1102   return success();
1103 }
1104 
1105 //===----------------------------------------------------------------------===//
1106 // XorOp
1107 //===----------------------------------------------------------------------===//
1108 
1109 LogicalResult
matchAndRewrite(XOrOp xorOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1110 XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
1111                               ConversionPatternRewriter &rewriter) const {
1112   assert(operands.size() == 2);
1113 
1114   if (isBoolScalarOrVector(operands.front().getType()))
1115     return failure();
1116 
1117   auto dstType = getTypeConverter()->convertType(xorOp.getType());
1118   if (!dstType)
1119     return failure();
1120   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(xorOp, dstType, operands);
1121 
1122   return success();
1123 }
1124 
1125 //===----------------------------------------------------------------------===//
1126 // Pattern population
1127 //===----------------------------------------------------------------------===//
1128 
1129 namespace mlir {
populateStandardToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,OwningRewritePatternList & patterns)1130 void populateStandardToSPIRVPatterns(MLIRContext *context,
1131                                      SPIRVTypeConverter &typeConverter,
1132                                      OwningRewritePatternList &patterns) {
1133   patterns.insert<
1134       // Unary and binary patterns
1135       BitwiseOpPattern<AndOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
1136       BitwiseOpPattern<OrOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
1137       UnaryAndBinaryOpPattern<AbsFOp, spirv::GLSLFAbsOp>,
1138       UnaryAndBinaryOpPattern<AddFOp, spirv::FAddOp>,
1139       UnaryAndBinaryOpPattern<AddIOp, spirv::IAddOp>,
1140       UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
1141       UnaryAndBinaryOpPattern<CosOp, spirv::GLSLCosOp>,
1142       UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
1143       UnaryAndBinaryOpPattern<ExpOp, spirv::GLSLExpOp>,
1144       UnaryAndBinaryOpPattern<FloorFOp, spirv::GLSLFloorOp>,
1145       UnaryAndBinaryOpPattern<LogOp, spirv::GLSLLogOp>,
1146       UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
1147       UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
1148       UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,
1149       UnaryAndBinaryOpPattern<RemFOp, spirv::FRemOp>,
1150       UnaryAndBinaryOpPattern<RsqrtOp, spirv::GLSLInverseSqrtOp>,
1151       UnaryAndBinaryOpPattern<ShiftLeftOp, spirv::ShiftLeftLogicalOp>,
1152       UnaryAndBinaryOpPattern<SignedDivIOp, spirv::SDivOp>,
1153       UnaryAndBinaryOpPattern<SignedShiftRightOp,
1154                               spirv::ShiftRightArithmeticOp>,
1155       UnaryAndBinaryOpPattern<SinOp, spirv::GLSLSinOp>,
1156       UnaryAndBinaryOpPattern<SqrtOp, spirv::GLSLSqrtOp>,
1157       UnaryAndBinaryOpPattern<SubFOp, spirv::FSubOp>,
1158       UnaryAndBinaryOpPattern<SubIOp, spirv::ISubOp>,
1159       UnaryAndBinaryOpPattern<TanhOp, spirv::GLSLTanhOp>,
1160       UnaryAndBinaryOpPattern<UnsignedDivIOp, spirv::UDivOp>,
1161       UnaryAndBinaryOpPattern<UnsignedRemIOp, spirv::UModOp>,
1162       UnaryAndBinaryOpPattern<UnsignedShiftRightOp, spirv::ShiftRightLogicalOp>,
1163       SignedRemIOpPattern, XOrOpPattern,
1164 
1165       // Comparison patterns
1166       BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
1167 
1168       // Constant patterns
1169       ConstantCompositeOpPattern, ConstantScalarOpPattern,
1170 
1171       // Memory patterns
1172       AllocOpPattern, DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
1173       LoadOpPattern, StoreOpPattern,
1174 
1175       ReturnOpPattern, SelectOpPattern,
1176 
1177       // Type cast patterns
1178       UIToFPI1Pattern, ZeroExtendI1Pattern,
1179       TypeCastingOpPattern<IndexCastOp, spirv::SConvertOp>,
1180       TypeCastingOpPattern<SIToFPOp, spirv::ConvertSToFOp>,
1181       TypeCastingOpPattern<UIToFPOp, spirv::ConvertUToFOp>,
1182       TypeCastingOpPattern<ZeroExtendIOp, spirv::UConvertOp>,
1183       TypeCastingOpPattern<TruncateIOp, spirv::SConvertOp>,
1184       TypeCastingOpPattern<FPToSIOp, spirv::ConvertFToSOp>,
1185       TypeCastingOpPattern<FPExtOp, spirv::FConvertOp>,
1186       TypeCastingOpPattern<FPTruncOp, spirv::FConvertOp>>(typeConverter,
1187                                                           context);
1188 
1189   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
1190   // capability is available.
1191   patterns.insert<CmpFOpNanKernelPattern>(typeConverter, context,
1192                                           /*benefit=*/2);
1193 }
1194 } // namespace mlir
1195