1 //===- SPIRVToLLVM.cpp - SPIR-V to LLVM Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert SPIR-V dialect to LLVM dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
14 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
15 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
16 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
19 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
20 #include "mlir/Dialect/StandardOps/IR/Ops.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Support/LogicalResult.h"
24 #include "mlir/Transforms/DialectConversion.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/FormatVariadic.h"
27
28 #define DEBUG_TYPE "spirv-to-llvm-pattern"
29
30 using namespace mlir;
31
32 //===----------------------------------------------------------------------===//
33 // Utility functions
34 //===----------------------------------------------------------------------===//
35
36 /// Returns true if the given type is a signed integer or vector type.
isSignedIntegerOrVector(Type type)37 static bool isSignedIntegerOrVector(Type type) {
38 if (type.isSignedInteger())
39 return true;
40 if (auto vecType = type.dyn_cast<VectorType>())
41 return vecType.getElementType().isSignedInteger();
42 return false;
43 }
44
45 /// Returns true if the given type is an unsigned integer or vector type
isUnsignedIntegerOrVector(Type type)46 static bool isUnsignedIntegerOrVector(Type type) {
47 if (type.isUnsignedInteger())
48 return true;
49 if (auto vecType = type.dyn_cast<VectorType>())
50 return vecType.getElementType().isUnsignedInteger();
51 return false;
52 }
53
54 /// Returns the bit width of integer, float or vector of float or integer values
getBitWidth(Type type)55 static unsigned getBitWidth(Type type) {
56 assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
57 "bitwidth is not supported for this type");
58 if (type.isIntOrFloat())
59 return type.getIntOrFloatBitWidth();
60 auto vecType = type.dyn_cast<VectorType>();
61 auto elementType = vecType.getElementType();
62 assert(elementType.isIntOrFloat() &&
63 "only integers and floats have a bitwidth");
64 return elementType.getIntOrFloatBitWidth();
65 }
66
67 /// Returns the bit width of LLVMType integer or vector.
getLLVMTypeBitWidth(Type type)68 static unsigned getLLVMTypeBitWidth(Type type) {
69 return (LLVM::isCompatibleVectorType(type) ? LLVM::getVectorElementType(type)
70 : type)
71 .cast<IntegerType>()
72 .getWidth();
73 }
74
75 /// Creates `IntegerAttribute` with all bits set for given type
minusOneIntegerAttribute(Type type,Builder builder)76 static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
77 if (auto vecType = type.dyn_cast<VectorType>()) {
78 auto integerType = vecType.getElementType().cast<IntegerType>();
79 return builder.getIntegerAttr(integerType, -1);
80 }
81 auto integerType = type.cast<IntegerType>();
82 return builder.getIntegerAttr(integerType, -1);
83 }
84
85 /// Creates `llvm.mlir.constant` with all bits set for the given type.
createConstantAllBitsSet(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter)86 static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType,
87 PatternRewriter &rewriter) {
88 if (srcType.isa<VectorType>()) {
89 return rewriter.create<LLVM::ConstantOp>(
90 loc, dstType,
91 SplatElementsAttr::get(srcType.cast<ShapedType>(),
92 minusOneIntegerAttribute(srcType, rewriter)));
93 }
94 return rewriter.create<LLVM::ConstantOp>(
95 loc, dstType, minusOneIntegerAttribute(srcType, rewriter));
96 }
97
98 /// Creates `llvm.mlir.constant` with a floating-point scalar or vector value.
createFPConstant(Location loc,Type srcType,Type dstType,PatternRewriter & rewriter,double value)99 static Value createFPConstant(Location loc, Type srcType, Type dstType,
100 PatternRewriter &rewriter, double value) {
101 if (auto vecType = srcType.dyn_cast<VectorType>()) {
102 auto floatType = vecType.getElementType().cast<FloatType>();
103 return rewriter.create<LLVM::ConstantOp>(
104 loc, dstType,
105 SplatElementsAttr::get(vecType,
106 rewriter.getFloatAttr(floatType, value)));
107 }
108 auto floatType = srcType.cast<FloatType>();
109 return rewriter.create<LLVM::ConstantOp>(
110 loc, dstType, rewriter.getFloatAttr(floatType, value));
111 }
112
113 /// Utility function for bitfield ops:
114 /// - `BitFieldInsert`
115 /// - `BitFieldSExtract`
116 /// - `BitFieldUExtract`
117 /// Truncates or extends the value. If the bitwidth of the value is the same as
118 /// `llvmType` bitwidth, the value remains unchanged.
optionallyTruncateOrExtend(Location loc,Value value,Type llvmType,PatternRewriter & rewriter)119 static Value optionallyTruncateOrExtend(Location loc, Value value,
120 Type llvmType,
121 PatternRewriter &rewriter) {
122 auto srcType = value.getType();
123 unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType);
124 unsigned valueBitWidth = LLVM::isCompatibleType(srcType)
125 ? getLLVMTypeBitWidth(srcType)
126 : getBitWidth(srcType);
127
128 if (valueBitWidth < targetBitWidth)
129 return rewriter.create<LLVM::ZExtOp>(loc, llvmType, value);
130 // If the bit widths of `Count` and `Offset` are greater than the bit width
131 // of the target type, they are truncated. Truncation is safe since `Count`
132 // and `Offset` must be no more than 64 for op behaviour to be defined. Hence,
133 // both values can be expressed in 8 bits.
134 if (valueBitWidth > targetBitWidth)
135 return rewriter.create<LLVM::TruncOp>(loc, llvmType, value);
136 return value;
137 }
138
139 /// Broadcasts the value to vector with `numElements` number of elements.
broadcast(Location loc,Value toBroadcast,unsigned numElements,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)140 static Value broadcast(Location loc, Value toBroadcast, unsigned numElements,
141 LLVMTypeConverter &typeConverter,
142 ConversionPatternRewriter &rewriter) {
143 auto vectorType = VectorType::get(numElements, toBroadcast.getType());
144 auto llvmVectorType = typeConverter.convertType(vectorType);
145 auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
146 Value broadcasted = rewriter.create<LLVM::UndefOp>(loc, llvmVectorType);
147 for (unsigned i = 0; i < numElements; ++i) {
148 auto index = rewriter.create<LLVM::ConstantOp>(
149 loc, llvmI32Type, rewriter.getI32IntegerAttr(i));
150 broadcasted = rewriter.create<LLVM::InsertElementOp>(
151 loc, llvmVectorType, broadcasted, toBroadcast, index);
152 }
153 return broadcasted;
154 }
155
156 /// Broadcasts the value. If `srcType` is a scalar, the value remains unchanged.
optionallyBroadcast(Location loc,Value value,Type srcType,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)157 static Value optionallyBroadcast(Location loc, Value value, Type srcType,
158 LLVMTypeConverter &typeConverter,
159 ConversionPatternRewriter &rewriter) {
160 if (auto vectorType = srcType.dyn_cast<VectorType>()) {
161 unsigned numElements = vectorType.getNumElements();
162 return broadcast(loc, value, numElements, typeConverter, rewriter);
163 }
164 return value;
165 }
166
167 /// Utility function for bitfield ops: `BitFieldInsert`, `BitFieldSExtract` and
168 /// `BitFieldUExtract`.
169 /// Broadcast `Offset` and `Count` to match the type of `Base`. If `Base` is of
170 /// a vector type, construct a vector that has:
171 /// - same number of elements as `Base`
172 /// - each element has the type that is the same as the type of `Offset` or
173 /// `Count`
174 /// - each element has the same value as `Offset` or `Count`
175 /// Then cast `Offset` and `Count` if their bit width is different
176 /// from `Base` bit width.
processCountOrOffset(Location loc,Value value,Type srcType,Type dstType,LLVMTypeConverter & converter,ConversionPatternRewriter & rewriter)177 static Value processCountOrOffset(Location loc, Value value, Type srcType,
178 Type dstType, LLVMTypeConverter &converter,
179 ConversionPatternRewriter &rewriter) {
180 Value broadcasted =
181 optionallyBroadcast(loc, value, srcType, converter, rewriter);
182 return optionallyTruncateOrExtend(loc, broadcasted, dstType, rewriter);
183 }
184
185 /// Converts SPIR-V struct with a regular (according to `VulkanLayoutUtils`)
186 /// offset to LLVM struct. Otherwise, the conversion is not supported.
187 static Optional<Type>
convertStructTypeWithOffset(spirv::StructType type,LLVMTypeConverter & converter)188 convertStructTypeWithOffset(spirv::StructType type,
189 LLVMTypeConverter &converter) {
190 if (type != VulkanLayoutUtils::decorateType(type))
191 return llvm::None;
192
193 auto elementsVector = llvm::to_vector<8>(
194 llvm::map_range(type.getElementTypes(), [&](Type elementType) {
195 return converter.convertType(elementType);
196 }));
197 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
198 /*isPacked=*/false);
199 }
200
201 /// Converts SPIR-V struct with no offset to packed LLVM struct.
convertStructTypePacked(spirv::StructType type,LLVMTypeConverter & converter)202 static Type convertStructTypePacked(spirv::StructType type,
203 LLVMTypeConverter &converter) {
204 auto elementsVector = llvm::to_vector<8>(
205 llvm::map_range(type.getElementTypes(), [&](Type elementType) {
206 return converter.convertType(elementType);
207 }));
208 return LLVM::LLVMStructType::getLiteral(type.getContext(), elementsVector,
209 /*isPacked=*/true);
210 }
211
212 /// Creates LLVM dialect constant with the given value.
createI32ConstantOf(Location loc,PatternRewriter & rewriter,unsigned value)213 static Value createI32ConstantOf(Location loc, PatternRewriter &rewriter,
214 unsigned value) {
215 return rewriter.create<LLVM::ConstantOp>(
216 loc, IntegerType::get(rewriter.getContext(), 32),
217 rewriter.getIntegerAttr(rewriter.getI32Type(), value));
218 }
219
220 /// Utility for `spv.Load` and `spv.Store` conversion.
replaceWithLoadOrStore(Operation * op,ConversionPatternRewriter & rewriter,LLVMTypeConverter & typeConverter,unsigned alignment,bool isVolatile,bool isNonTemporal)221 static LogicalResult replaceWithLoadOrStore(Operation *op,
222 ConversionPatternRewriter &rewriter,
223 LLVMTypeConverter &typeConverter,
224 unsigned alignment, bool isVolatile,
225 bool isNonTemporal) {
226 if (auto loadOp = dyn_cast<spirv::LoadOp>(op)) {
227 auto dstType = typeConverter.convertType(loadOp.getType());
228 if (!dstType)
229 return failure();
230 rewriter.replaceOpWithNewOp<LLVM::LoadOp>(
231 loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal);
232 return success();
233 }
234 auto storeOp = cast<spirv::StoreOp>(op);
235 rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, storeOp.value(),
236 storeOp.ptr(), alignment,
237 isVolatile, isNonTemporal);
238 return success();
239 }
240
241 //===----------------------------------------------------------------------===//
242 // Type conversion
243 //===----------------------------------------------------------------------===//
244
245 /// Converts SPIR-V array type to LLVM array. Natural stride (according to
246 /// `VulkanLayoutUtils`) is also mapped to LLVM array. This has to be respected
247 /// when converting ops that manipulate array types.
convertArrayType(spirv::ArrayType type,TypeConverter & converter)248 static Optional<Type> convertArrayType(spirv::ArrayType type,
249 TypeConverter &converter) {
250 unsigned stride = type.getArrayStride();
251 Type elementType = type.getElementType();
252 auto sizeInBytes = elementType.cast<spirv::SPIRVType>().getSizeInBytes();
253 if (stride != 0 &&
254 !(sizeInBytes.hasValue() && sizeInBytes.getValue() == stride))
255 return llvm::None;
256
257 auto llvmElementType = converter.convertType(elementType);
258 unsigned numElements = type.getNumElements();
259 return LLVM::LLVMArrayType::get(llvmElementType, numElements);
260 }
261
262 /// Converts SPIR-V pointer type to LLVM pointer. Pointer's storage class is not
263 /// modelled at the moment.
convertPointerType(spirv::PointerType type,TypeConverter & converter)264 static Type convertPointerType(spirv::PointerType type,
265 TypeConverter &converter) {
266 auto pointeeType = converter.convertType(type.getPointeeType());
267 return LLVM::LLVMPointerType::get(pointeeType);
268 }
269
270 /// Converts SPIR-V runtime array to LLVM array. Since LLVM allows indexing over
271 /// the bounds, the runtime array is converted to a 0-sized LLVM array. There is
272 /// no modelling of array stride at the moment.
convertRuntimeArrayType(spirv::RuntimeArrayType type,TypeConverter & converter)273 static Optional<Type> convertRuntimeArrayType(spirv::RuntimeArrayType type,
274 TypeConverter &converter) {
275 if (type.getArrayStride() != 0)
276 return llvm::None;
277 auto elementType = converter.convertType(type.getElementType());
278 return LLVM::LLVMArrayType::get(elementType, 0);
279 }
280
281 /// Converts SPIR-V struct to LLVM struct. There is no support of structs with
282 /// member decorations. Also, only natural offset is supported.
convertStructType(spirv::StructType type,LLVMTypeConverter & converter)283 static Optional<Type> convertStructType(spirv::StructType type,
284 LLVMTypeConverter &converter) {
285 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
286 type.getMemberDecorations(memberDecorations);
287 if (!memberDecorations.empty())
288 return llvm::None;
289 if (type.hasOffset())
290 return convertStructTypeWithOffset(type, converter);
291 return convertStructTypePacked(type, converter);
292 }
293
294 //===----------------------------------------------------------------------===//
295 // Operation conversion
296 //===----------------------------------------------------------------------===//
297
298 namespace {
299
300 class AccessChainPattern : public SPIRVToLLVMConversion<spirv::AccessChainOp> {
301 public:
302 using SPIRVToLLVMConversion<spirv::AccessChainOp>::SPIRVToLLVMConversion;
303
304 LogicalResult
matchAndRewrite(spirv::AccessChainOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const305 matchAndRewrite(spirv::AccessChainOp op, ArrayRef<Value> operands,
306 ConversionPatternRewriter &rewriter) const override {
307 auto dstType = typeConverter.convertType(op.component_ptr().getType());
308 if (!dstType)
309 return failure();
310 // To use GEP we need to add a first 0 index to go through the pointer.
311 auto indices = llvm::to_vector<4>(op.indices());
312 Type indexType = op.indices().front().getType();
313 auto llvmIndexType = typeConverter.convertType(indexType);
314 if (!llvmIndexType)
315 return failure();
316 Value zero = rewriter.create<LLVM::ConstantOp>(
317 op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0));
318 indices.insert(indices.begin(), zero);
319 rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, dstType, op.base_ptr(),
320 indices);
321 return success();
322 }
323 };
324
325 class AddressOfPattern : public SPIRVToLLVMConversion<spirv::AddressOfOp> {
326 public:
327 using SPIRVToLLVMConversion<spirv::AddressOfOp>::SPIRVToLLVMConversion;
328
329 LogicalResult
matchAndRewrite(spirv::AddressOfOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const330 matchAndRewrite(spirv::AddressOfOp op, ArrayRef<Value> operands,
331 ConversionPatternRewriter &rewriter) const override {
332 auto dstType = typeConverter.convertType(op.pointer().getType());
333 if (!dstType)
334 return failure();
335 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, dstType, op.variable());
336 return success();
337 }
338 };
339
340 class BitFieldInsertPattern
341 : public SPIRVToLLVMConversion<spirv::BitFieldInsertOp> {
342 public:
343 using SPIRVToLLVMConversion<spirv::BitFieldInsertOp>::SPIRVToLLVMConversion;
344
345 LogicalResult
matchAndRewrite(spirv::BitFieldInsertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const346 matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef<Value> operands,
347 ConversionPatternRewriter &rewriter) const override {
348 auto srcType = op.getType();
349 auto dstType = typeConverter.convertType(srcType);
350 if (!dstType)
351 return failure();
352 Location loc = op.getLoc();
353
354 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
355 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
356 typeConverter, rewriter);
357 Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
358 typeConverter, rewriter);
359
360 // Create a mask with bits set outside [Offset, Offset + Count - 1].
361 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
362 Value maskShiftedByCount =
363 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
364 Value negated = rewriter.create<LLVM::XOrOp>(loc, dstType,
365 maskShiftedByCount, minusOne);
366 Value maskShiftedByCountAndOffset =
367 rewriter.create<LLVM::ShlOp>(loc, dstType, negated, offset);
368 Value mask = rewriter.create<LLVM::XOrOp>(
369 loc, dstType, maskShiftedByCountAndOffset, minusOne);
370
371 // Extract unchanged bits from the `Base` that are outside of
372 // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`.
373 Value baseAndMask =
374 rewriter.create<LLVM::AndOp>(loc, dstType, op.base(), mask);
375 Value insertShiftedByOffset =
376 rewriter.create<LLVM::ShlOp>(loc, dstType, op.insert(), offset);
377 rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask,
378 insertShiftedByOffset);
379 return success();
380 }
381 };
382
383 /// Converts SPIR-V ConstantOp with scalar or vector type.
384 class ConstantScalarAndVectorPattern
385 : public SPIRVToLLVMConversion<spirv::ConstantOp> {
386 public:
387 using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
388
389 LogicalResult
matchAndRewrite(spirv::ConstantOp constOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const390 matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
391 ConversionPatternRewriter &rewriter) const override {
392 auto srcType = constOp.getType();
393 if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
394 return failure();
395
396 auto dstType = typeConverter.convertType(srcType);
397 if (!dstType)
398 return failure();
399
400 // SPIR-V constant can be a signed/unsigned integer, which has to be
401 // casted to signless integer when converting to LLVM dialect. Removing the
402 // sign bit may have unexpected behaviour. However, it is better to handle
403 // it case-by-case, given that the purpose of the conversion is not to
404 // cover all possible corner cases.
405 if (isSignedIntegerOrVector(srcType) ||
406 isUnsignedIntegerOrVector(srcType)) {
407 auto signlessType = rewriter.getIntegerType(getBitWidth(srcType));
408
409 if (srcType.isa<VectorType>()) {
410 auto dstElementsAttr = constOp.value().cast<DenseIntElementsAttr>();
411 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
412 constOp, dstType,
413 dstElementsAttr.mapValues(
414 signlessType, [&](const APInt &value) { return value; }));
415 return success();
416 }
417 auto srcAttr = constOp.value().cast<IntegerAttr>();
418 auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
419 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
420 return success();
421 }
422 rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
423 constOp.getAttrs());
424 return success();
425 }
426 };
427
428 class BitFieldSExtractPattern
429 : public SPIRVToLLVMConversion<spirv::BitFieldSExtractOp> {
430 public:
431 using SPIRVToLLVMConversion<spirv::BitFieldSExtractOp>::SPIRVToLLVMConversion;
432
433 LogicalResult
matchAndRewrite(spirv::BitFieldSExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const434 matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef<Value> operands,
435 ConversionPatternRewriter &rewriter) const override {
436 auto srcType = op.getType();
437 auto dstType = typeConverter.convertType(srcType);
438 if (!dstType)
439 return failure();
440 Location loc = op.getLoc();
441
442 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
443 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
444 typeConverter, rewriter);
445 Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
446 typeConverter, rewriter);
447
448 // Create a constant that holds the size of the `Base`.
449 IntegerType integerType;
450 if (auto vecType = srcType.dyn_cast<VectorType>())
451 integerType = vecType.getElementType().cast<IntegerType>();
452 else
453 integerType = srcType.cast<IntegerType>();
454
455 auto baseSize = rewriter.getIntegerAttr(integerType, getBitWidth(srcType));
456 Value size =
457 srcType.isa<VectorType>()
458 ? rewriter.create<LLVM::ConstantOp>(
459 loc, dstType,
460 SplatElementsAttr::get(srcType.cast<ShapedType>(), baseSize))
461 : rewriter.create<LLVM::ConstantOp>(loc, dstType, baseSize);
462
463 // Shift `Base` left by [sizeof(Base) - (Count + Offset)], so that the bit
464 // at Offset + Count - 1 is the most significant bit now.
465 Value countPlusOffset =
466 rewriter.create<LLVM::AddOp>(loc, dstType, count, offset);
467 Value amountToShiftLeft =
468 rewriter.create<LLVM::SubOp>(loc, dstType, size, countPlusOffset);
469 Value baseShiftedLeft = rewriter.create<LLVM::ShlOp>(
470 loc, dstType, op.base(), amountToShiftLeft);
471
472 // Shift the result right, filling the bits with the sign bit.
473 Value amountToShiftRight =
474 rewriter.create<LLVM::AddOp>(loc, dstType, offset, amountToShiftLeft);
475 rewriter.replaceOpWithNewOp<LLVM::AShrOp>(op, dstType, baseShiftedLeft,
476 amountToShiftRight);
477 return success();
478 }
479 };
480
481 class BitFieldUExtractPattern
482 : public SPIRVToLLVMConversion<spirv::BitFieldUExtractOp> {
483 public:
484 using SPIRVToLLVMConversion<spirv::BitFieldUExtractOp>::SPIRVToLLVMConversion;
485
486 LogicalResult
matchAndRewrite(spirv::BitFieldUExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const487 matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef<Value> operands,
488 ConversionPatternRewriter &rewriter) const override {
489 auto srcType = op.getType();
490 auto dstType = typeConverter.convertType(srcType);
491 if (!dstType)
492 return failure();
493 Location loc = op.getLoc();
494
495 // Process `Offset` and `Count`: broadcast and extend/truncate if needed.
496 Value offset = processCountOrOffset(loc, op.offset(), srcType, dstType,
497 typeConverter, rewriter);
498 Value count = processCountOrOffset(loc, op.count(), srcType, dstType,
499 typeConverter, rewriter);
500
501 // Create a mask with bits set at [0, Count - 1].
502 Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter);
503 Value maskShiftedByCount =
504 rewriter.create<LLVM::ShlOp>(loc, dstType, minusOne, count);
505 Value mask = rewriter.create<LLVM::XOrOp>(loc, dstType, maskShiftedByCount,
506 minusOne);
507
508 // Shift `Base` by `Offset` and apply the mask on it.
509 Value shiftedBase =
510 rewriter.create<LLVM::LShrOp>(loc, dstType, op.base(), offset);
511 rewriter.replaceOpWithNewOp<LLVM::AndOp>(op, dstType, shiftedBase, mask);
512 return success();
513 }
514 };
515
516 class BranchConversionPattern : public SPIRVToLLVMConversion<spirv::BranchOp> {
517 public:
518 using SPIRVToLLVMConversion<spirv::BranchOp>::SPIRVToLLVMConversion;
519
520 LogicalResult
matchAndRewrite(spirv::BranchOp branchOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const521 matchAndRewrite(spirv::BranchOp branchOp, ArrayRef<Value> operands,
522 ConversionPatternRewriter &rewriter) const override {
523 rewriter.replaceOpWithNewOp<LLVM::BrOp>(branchOp, operands,
524 branchOp.getTarget());
525 return success();
526 }
527 };
528
529 class BranchConditionalConversionPattern
530 : public SPIRVToLLVMConversion<spirv::BranchConditionalOp> {
531 public:
532 using SPIRVToLLVMConversion<
533 spirv::BranchConditionalOp>::SPIRVToLLVMConversion;
534
535 LogicalResult
matchAndRewrite(spirv::BranchConditionalOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const536 matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef<Value> operands,
537 ConversionPatternRewriter &rewriter) const override {
538 // If branch weights exist, map them to 32-bit integer vector.
539 ElementsAttr branchWeights = nullptr;
540 if (auto weights = op.branch_weights()) {
541 VectorType weightType = VectorType::get(2, rewriter.getI32Type());
542 branchWeights =
543 DenseElementsAttr::get(weightType, weights.getValue().getValue());
544 }
545
546 rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
547 op, op.condition(), op.getTrueBlockArguments(),
548 op.getFalseBlockArguments(), branchWeights, op.getTrueBlock(),
549 op.getFalseBlock());
550 return success();
551 }
552 };
553
554 /// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
555 /// is an aggregate type (struct or array). Otherwise, converts to
556 /// `llvm.extractelement` that operates on vectors.
557 class CompositeExtractPattern
558 : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
559 public:
560 using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
561
562 LogicalResult
matchAndRewrite(spirv::CompositeExtractOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const563 matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
564 ConversionPatternRewriter &rewriter) const override {
565 auto dstType = this->typeConverter.convertType(op.getType());
566 if (!dstType)
567 return failure();
568
569 Type containerType = op.composite().getType();
570 if (containerType.isa<VectorType>()) {
571 Location loc = op.getLoc();
572 IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
573 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
574 rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
575 op, dstType, op.composite(), index);
576 return success();
577 }
578 rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
579 op, dstType, op.composite(), op.indices());
580 return success();
581 }
582 };
583
584 /// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
585 /// is an aggregate type (struct or array). Otherwise, converts to
586 /// `llvm.insertelement` that operates on vectors.
587 class CompositeInsertPattern
588 : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
589 public:
590 using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
591
592 LogicalResult
matchAndRewrite(spirv::CompositeInsertOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const593 matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
594 ConversionPatternRewriter &rewriter) const override {
595 auto dstType = this->typeConverter.convertType(op.getType());
596 if (!dstType)
597 return failure();
598
599 Type containerType = op.composite().getType();
600 if (containerType.isa<VectorType>()) {
601 Location loc = op.getLoc();
602 IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
603 Value index = createI32ConstantOf(loc, rewriter, value.getInt());
604 rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
605 op, dstType, op.composite(), op.object(), index);
606 return success();
607 }
608 rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
609 op, dstType, op.composite(), op.object(), op.indices());
610 return success();
611 }
612 };
613
614 /// Converts SPIR-V operations that have straightforward LLVM equivalent
615 /// into LLVM dialect operations.
616 template <typename SPIRVOp, typename LLVMOp>
617 class DirectConversionPattern : public SPIRVToLLVMConversion<SPIRVOp> {
618 public:
619 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
620
621 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const622 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
623 ConversionPatternRewriter &rewriter) const override {
624 auto dstType = this->typeConverter.convertType(operation.getType());
625 if (!dstType)
626 return failure();
627 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType, operands,
628 operation.getAttrs());
629 return success();
630 }
631 };
632
633 /// Converts `spv.ExecutionMode` into a global struct constant that holds
634 /// execution mode information.
635 class ExecutionModePattern
636 : public SPIRVToLLVMConversion<spirv::ExecutionModeOp> {
637 public:
638 using SPIRVToLLVMConversion<spirv::ExecutionModeOp>::SPIRVToLLVMConversion;
639
640 LogicalResult
matchAndRewrite(spirv::ExecutionModeOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const641 matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef<Value> operands,
642 ConversionPatternRewriter &rewriter) const override {
643 // First, create the global struct's name that would be associated with
644 // this entry point's execution mode. We set it to be:
645 // __spv__{SPIR-V module name}_{function name}_execution_mode_info
646 ModuleOp module = op->getParentOfType<ModuleOp>();
647 std::string moduleName;
648 if (module.getName().hasValue())
649 moduleName = "_" + module.getName().getValue().str();
650 else
651 moduleName = "";
652 std::string executionModeInfoName = llvm::formatv(
653 "__spv_{0}_{1}_execution_mode_info", moduleName, op.fn().str());
654
655 MLIRContext *context = rewriter.getContext();
656 OpBuilder::InsertionGuard guard(rewriter);
657 rewriter.setInsertionPointToStart(module.getBody());
658
659 // Create a struct type, corresponding to the C struct below.
660 // struct {
661 // int32_t executionMode;
662 // int32_t values[]; // optional values
663 // };
664 auto llvmI32Type = IntegerType::get(context, 32);
665 SmallVector<Type, 2> fields;
666 fields.push_back(llvmI32Type);
667 ArrayAttr values = op.values();
668 if (!values.empty()) {
669 auto arrayType = LLVM::LLVMArrayType::get(llvmI32Type, values.size());
670 fields.push_back(arrayType);
671 }
672 auto structType = LLVM::LLVMStructType::getLiteral(context, fields);
673
674 // Create `llvm.mlir.global` with initializer region containing one block.
675 auto global = rewriter.create<LLVM::GlobalOp>(
676 UnknownLoc::get(context), structType, /*isConstant=*/true,
677 LLVM::Linkage::External, executionModeInfoName, Attribute());
678 Location loc = global.getLoc();
679 Region ®ion = global.getInitializerRegion();
680 Block *block = rewriter.createBlock(®ion);
681
682 // Initialize the struct and set the execution mode value.
683 rewriter.setInsertionPoint(block, block->begin());
684 Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
685 IntegerAttr executionModeAttr = op.execution_modeAttr();
686 Value executionMode =
687 rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
688 structValue = rewriter.create<LLVM::InsertValueOp>(
689 loc, structType, structValue, executionMode,
690 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)},
691 context));
692
693 // Insert extra operands if they exist into execution mode info struct.
694 for (unsigned i = 0, e = values.size(); i < e; ++i) {
695 auto attr = values.getValue()[i];
696 Value entry = rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, attr);
697 structValue = rewriter.create<LLVM::InsertValueOp>(
698 loc, structType, structValue, entry,
699 ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1),
700 rewriter.getIntegerAttr(rewriter.getI32Type(), i)},
701 context));
702 }
703 rewriter.create<LLVM::ReturnOp>(loc, ArrayRef<Value>({structValue}));
704 rewriter.eraseOp(op);
705 return success();
706 }
707 };
708
709 /// Converts `spv.globalVariable` to `llvm.mlir.global`. Note that SPIR-V global
710 /// returns a pointer, whereas in LLVM dialect the global holds an actual value.
711 /// This difference is handled by `spv.mlir.addressof` and
712 /// `llvm.mlir.addressof`ops that both return a pointer.
713 class GlobalVariablePattern
714 : public SPIRVToLLVMConversion<spirv::GlobalVariableOp> {
715 public:
716 using SPIRVToLLVMConversion<spirv::GlobalVariableOp>::SPIRVToLLVMConversion;
717
718 LogicalResult
matchAndRewrite(spirv::GlobalVariableOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const719 matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef<Value> operands,
720 ConversionPatternRewriter &rewriter) const override {
721 // Currently, there is no support of initialization with a constant value in
722 // SPIR-V dialect. Specialization constants are not considered as well.
723 if (op.initializer())
724 return failure();
725
726 auto srcType = op.type().cast<spirv::PointerType>();
727 auto dstType = typeConverter.convertType(srcType.getPointeeType());
728 if (!dstType)
729 return failure();
730
731 // Limit conversion to the current invocation only or `StorageBuffer`
732 // required by SPIR-V runner.
733 // This is okay because multiple invocations are not supported yet.
734 auto storageClass = srcType.getStorageClass();
735 if (storageClass != spirv::StorageClass::Input &&
736 storageClass != spirv::StorageClass::Private &&
737 storageClass != spirv::StorageClass::Output &&
738 storageClass != spirv::StorageClass::StorageBuffer) {
739 return failure();
740 }
741
742 // LLVM dialect spec: "If the global value is a constant, storing into it is
743 // not allowed.". This corresponds to SPIR-V 'Input' storage class that is
744 // read-only.
745 bool isConstant = storageClass == spirv::StorageClass::Input;
746 // SPIR-V spec: "By default, functions and global variables are private to a
747 // module and cannot be accessed by other modules. However, a module may be
748 // written to export or import functions and global (module scope)
749 // variables.". Therefore, map 'Private' storage class to private linkage,
750 // 'Input' and 'Output' to external linkage.
751 auto linkage = storageClass == spirv::StorageClass::Private
752 ? LLVM::Linkage::Private
753 : LLVM::Linkage::External;
754 rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
755 op, dstType, isConstant, linkage, op.sym_name(), Attribute());
756 return success();
757 }
758 };
759
760 /// Converts SPIR-V cast ops that do not have straightforward LLVM
761 /// equivalent in LLVM dialect.
762 template <typename SPIRVOp, typename LLVMExtOp, typename LLVMTruncOp>
763 class IndirectCastPattern : public SPIRVToLLVMConversion<SPIRVOp> {
764 public:
765 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
766
767 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const768 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
769 ConversionPatternRewriter &rewriter) const override {
770
771 Type fromType = operation.operand().getType();
772 Type toType = operation.getType();
773
774 auto dstType = this->typeConverter.convertType(toType);
775 if (!dstType)
776 return failure();
777
778 if (getBitWidth(fromType) < getBitWidth(toType)) {
779 rewriter.template replaceOpWithNewOp<LLVMExtOp>(operation, dstType,
780 operands);
781 return success();
782 }
783 if (getBitWidth(fromType) > getBitWidth(toType)) {
784 rewriter.template replaceOpWithNewOp<LLVMTruncOp>(operation, dstType,
785 operands);
786 return success();
787 }
788 return failure();
789 }
790 };
791
792 class FunctionCallPattern
793 : public SPIRVToLLVMConversion<spirv::FunctionCallOp> {
794 public:
795 using SPIRVToLLVMConversion<spirv::FunctionCallOp>::SPIRVToLLVMConversion;
796
797 LogicalResult
matchAndRewrite(spirv::FunctionCallOp callOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const798 matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef<Value> operands,
799 ConversionPatternRewriter &rewriter) const override {
800 if (callOp.getNumResults() == 0) {
801 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, llvm::None, operands,
802 callOp.getAttrs());
803 return success();
804 }
805
806 // Function returns a single result.
807 auto dstType = typeConverter.convertType(callOp.getType(0));
808 rewriter.replaceOpWithNewOp<LLVM::CallOp>(callOp, dstType, operands,
809 callOp.getAttrs());
810 return success();
811 }
812 };
813
814 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
815 template <typename SPIRVOp, LLVM::FCmpPredicate predicate>
816 class FComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
817 public:
818 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
819
820 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const821 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
822 ConversionPatternRewriter &rewriter) const override {
823
824 auto dstType = this->typeConverter.convertType(operation.getType());
825 if (!dstType)
826 return failure();
827
828 rewriter.template replaceOpWithNewOp<LLVM::FCmpOp>(
829 operation, dstType,
830 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
831 operation.operand1(), operation.operand2(),
832 LLVM::FMFAttr::get({}, operation.getContext()));
833 return success();
834 }
835 };
836
837 /// Converts SPIR-V integer comparisons to llvm.icmp "predicate"
838 template <typename SPIRVOp, LLVM::ICmpPredicate predicate>
839 class IComparePattern : public SPIRVToLLVMConversion<SPIRVOp> {
840 public:
841 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
842
843 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const844 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
845 ConversionPatternRewriter &rewriter) const override {
846
847 auto dstType = this->typeConverter.convertType(operation.getType());
848 if (!dstType)
849 return failure();
850
851 rewriter.template replaceOpWithNewOp<LLVM::ICmpOp>(
852 operation, dstType,
853 rewriter.getI64IntegerAttr(static_cast<int64_t>(predicate)),
854 operation.operand1(), operation.operand2());
855 return success();
856 }
857 };
858
859 class InverseSqrtPattern
860 : public SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp> {
861 public:
862 using SPIRVToLLVMConversion<spirv::GLSLInverseSqrtOp>::SPIRVToLLVMConversion;
863
864 LogicalResult
matchAndRewrite(spirv::GLSLInverseSqrtOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const865 matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef<Value> operands,
866 ConversionPatternRewriter &rewriter) const override {
867 auto srcType = op.getType();
868 auto dstType = typeConverter.convertType(srcType);
869 if (!dstType)
870 return failure();
871
872 Location loc = op.getLoc();
873 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
874 Value sqrt = rewriter.create<LLVM::SqrtOp>(loc, dstType, op.operand());
875 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, dstType, one, sqrt);
876 return success();
877 }
878 };
879
880 /// Converts `spv.Load` and `spv.Store` to LLVM dialect.
881 template <typename SPIRVop>
882 class LoadStorePattern : public SPIRVToLLVMConversion<SPIRVop> {
883 public:
884 using SPIRVToLLVMConversion<SPIRVop>::SPIRVToLLVMConversion;
885
886 LogicalResult
matchAndRewrite(SPIRVop op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const887 matchAndRewrite(SPIRVop op, ArrayRef<Value> operands,
888 ConversionPatternRewriter &rewriter) const override {
889
890 if (!op.memory_access().hasValue()) {
891 replaceWithLoadOrStore(op, rewriter, this->typeConverter, /*alignment=*/0,
892 /*isVolatile=*/false, /*isNonTemporal=*/false);
893 return success();
894 }
895 auto memoryAccess = op.memory_access().getValue();
896 switch (memoryAccess) {
897 case spirv::MemoryAccess::Aligned:
898 case spirv::MemoryAccess::None:
899 case spirv::MemoryAccess::Nontemporal:
900 case spirv::MemoryAccess::Volatile: {
901 unsigned alignment =
902 memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0;
903 bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal;
904 bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile;
905 replaceWithLoadOrStore(op, rewriter, this->typeConverter, alignment,
906 isVolatile, isNonTemporal);
907 return success();
908 }
909 default:
910 // There is no support of other memory access attributes.
911 return failure();
912 }
913 }
914 };
915
916 /// Converts `spv.Not` and `spv.LogicalNot` into LLVM dialect.
917 template <typename SPIRVOp>
918 class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
919 public:
920 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
921
922 LogicalResult
matchAndRewrite(SPIRVOp notOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const923 matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
924 ConversionPatternRewriter &rewriter) const override {
925
926 auto srcType = notOp.getType();
927 auto dstType = this->typeConverter.convertType(srcType);
928 if (!dstType)
929 return failure();
930
931 Location loc = notOp.getLoc();
932 IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
933 auto mask = srcType.template isa<VectorType>()
934 ? rewriter.create<LLVM::ConstantOp>(
935 loc, dstType,
936 SplatElementsAttr::get(
937 srcType.template cast<VectorType>(), minusOne))
938 : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
939 rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
940 notOp.operand(), mask);
941 return success();
942 }
943 };
944
945 /// A template pattern that erases the given `SPIRVOp`.
946 template <typename SPIRVOp>
947 class ErasePattern : public SPIRVToLLVMConversion<SPIRVOp> {
948 public:
949 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
950
951 LogicalResult
matchAndRewrite(SPIRVOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const952 matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
953 ConversionPatternRewriter &rewriter) const override {
954 rewriter.eraseOp(op);
955 return success();
956 }
957 };
958
959 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
960 public:
961 using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
962
963 LogicalResult
matchAndRewrite(spirv::ReturnOp returnOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const964 matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef<Value> operands,
965 ConversionPatternRewriter &rewriter) const override {
966 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp, ArrayRef<Type>(),
967 ArrayRef<Value>());
968 return success();
969 }
970 };
971
972 class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
973 public:
974 using SPIRVToLLVMConversion<spirv::ReturnValueOp>::SPIRVToLLVMConversion;
975
976 LogicalResult
matchAndRewrite(spirv::ReturnValueOp returnValueOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const977 matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef<Value> operands,
978 ConversionPatternRewriter &rewriter) const override {
979 rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnValueOp, ArrayRef<Type>(),
980 operands);
981 return success();
982 }
983 };
984
985 /// Converts `spv.loop` to LLVM dialect. All blocks within selection should be
986 /// reachable for conversion to succeed.
987 /// The structure of the loop in LLVM dialect will be the following:
988 ///
989 /// +------------------------------------+
990 /// | <code before spv.loop> |
991 /// | llvm.br ^header |
992 /// +------------------------------------+
993 /// |
994 /// +----------------+ |
995 /// | | |
996 /// | V V
997 /// | +------------------------------------+
998 /// | | ^header: |
999 /// | | <header code> |
1000 /// | | llvm.cond_br %cond, ^body, ^exit |
1001 /// | +------------------------------------+
1002 /// | |
1003 /// | |----------------------+
1004 /// | | |
1005 /// | V |
1006 /// | +------------------------------------+ |
1007 /// | | ^body: | |
1008 /// | | <body code> | |
1009 /// | | llvm.br ^continue | |
1010 /// | +------------------------------------+ |
1011 /// | | |
1012 /// | V |
1013 /// | +------------------------------------+ |
1014 /// | | ^continue: | |
1015 /// | | <continue code> | |
1016 /// | | llvm.br ^header | |
1017 /// | +------------------------------------+ |
1018 /// | | |
1019 /// +---------------+ +----------------------+
1020 /// |
1021 /// V
1022 /// +------------------------------------+
1023 /// | ^exit: |
1024 /// | llvm.br ^remaining |
1025 /// +------------------------------------+
1026 /// |
1027 /// V
1028 /// +------------------------------------+
1029 /// | ^remaining: |
1030 /// | <code after spv.loop> |
1031 /// +------------------------------------+
1032 ///
1033 class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
1034 public:
1035 using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
1036
1037 LogicalResult
matchAndRewrite(spirv::LoopOp loopOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1038 matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
1039 ConversionPatternRewriter &rewriter) const override {
1040 // There is no support of loop control at the moment.
1041 if (loopOp.loop_control() != spirv::LoopControl::None)
1042 return failure();
1043
1044 Location loc = loopOp.getLoc();
1045
1046 // Split the current block after `spv.loop`. The remaining ops will be used
1047 // in `endBlock`.
1048 Block *currentBlock = rewriter.getBlock();
1049 auto position = Block::iterator(loopOp);
1050 Block *endBlock = rewriter.splitBlock(currentBlock, position);
1051
1052 // Remove entry block and create a branch in the current block going to the
1053 // header block.
1054 Block *entryBlock = loopOp.getEntryBlock();
1055 assert(entryBlock->getOperations().size() == 1);
1056 auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
1057 if (!brOp)
1058 return failure();
1059 Block *headerBlock = loopOp.getHeaderBlock();
1060 rewriter.setInsertionPointToEnd(currentBlock);
1061 rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
1062 rewriter.eraseBlock(entryBlock);
1063
1064 // Branch from merge block to end block.
1065 Block *mergeBlock = loopOp.getMergeBlock();
1066 Operation *terminator = mergeBlock->getTerminator();
1067 ValueRange terminatorOperands = terminator->getOperands();
1068 rewriter.setInsertionPointToEnd(mergeBlock);
1069 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
1070
1071 rewriter.inlineRegionBefore(loopOp.body(), endBlock);
1072 rewriter.replaceOp(loopOp, endBlock->getArguments());
1073 return success();
1074 }
1075 };
1076
1077 /// Converts `spv.selection` with `spv.BranchConditional` in its header block.
1078 /// All blocks within selection should be reachable for conversion to succeed.
1079 class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
1080 public:
1081 using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
1082
1083 LogicalResult
matchAndRewrite(spirv::SelectionOp op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1084 matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
1085 ConversionPatternRewriter &rewriter) const override {
1086 // There is no support for `Flatten` or `DontFlatten` selection control at
1087 // the moment. This are just compiler hints and can be performed during the
1088 // optimization passes.
1089 if (op.selection_control() != spirv::SelectionControl::None)
1090 return failure();
1091
1092 // `spv.selection` should have at least two blocks: one selection header
1093 // block and one merge block. If no blocks are present, or control flow
1094 // branches straight to merge block (two blocks are present), the op is
1095 // redundant and it is erased.
1096 if (op.body().getBlocks().size() <= 2) {
1097 rewriter.eraseOp(op);
1098 return success();
1099 }
1100
1101 Location loc = op.getLoc();
1102
1103 // Split the current block after `spv.selection`. The remaining ops will be
1104 // used in `continueBlock`.
1105 auto *currentBlock = rewriter.getInsertionBlock();
1106 rewriter.setInsertionPointAfter(op);
1107 auto position = rewriter.getInsertionPoint();
1108 auto *continueBlock = rewriter.splitBlock(currentBlock, position);
1109
1110 // Extract conditional branch information from the header block. By SPIR-V
1111 // dialect spec, it should contain `spv.BranchConditional` or `spv.Switch`
1112 // op. Note that `spv.Switch op` is not supported at the moment in the
1113 // SPIR-V dialect. Remove this block when finished.
1114 auto *headerBlock = op.getHeaderBlock();
1115 assert(headerBlock->getOperations().size() == 1);
1116 auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
1117 headerBlock->getOperations().front());
1118 if (!condBrOp)
1119 return failure();
1120 rewriter.eraseBlock(headerBlock);
1121
1122 // Branch from merge block to continue block.
1123 auto *mergeBlock = op.getMergeBlock();
1124 Operation *terminator = mergeBlock->getTerminator();
1125 ValueRange terminatorOperands = terminator->getOperands();
1126 rewriter.setInsertionPointToEnd(mergeBlock);
1127 rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
1128
1129 // Link current block to `true` and `false` blocks within the selection.
1130 Block *trueBlock = condBrOp.getTrueBlock();
1131 Block *falseBlock = condBrOp.getFalseBlock();
1132 rewriter.setInsertionPointToEnd(currentBlock);
1133 rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(), trueBlock,
1134 condBrOp.trueTargetOperands(), falseBlock,
1135 condBrOp.falseTargetOperands());
1136
1137 rewriter.inlineRegionBefore(op.body(), continueBlock);
1138 rewriter.replaceOp(op, continueBlock->getArguments());
1139 return success();
1140 }
1141 };
1142
1143 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
1144 /// puts a restriction on `Shift` and `Base` to have the same bit width,
1145 /// `Shift` is zero or sign extended to match this specification. Cases when
1146 /// `Shift` bit width > `Base` bit width are considered to be illegal.
1147 template <typename SPIRVOp, typename LLVMOp>
1148 class ShiftPattern : public SPIRVToLLVMConversion<SPIRVOp> {
1149 public:
1150 using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
1151
1152 LogicalResult
matchAndRewrite(SPIRVOp operation,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1153 matchAndRewrite(SPIRVOp operation, ArrayRef<Value> operands,
1154 ConversionPatternRewriter &rewriter) const override {
1155
1156 auto dstType = this->typeConverter.convertType(operation.getType());
1157 if (!dstType)
1158 return failure();
1159
1160 Type op1Type = operation.operand1().getType();
1161 Type op2Type = operation.operand2().getType();
1162
1163 if (op1Type == op2Type) {
1164 rewriter.template replaceOpWithNewOp<LLVMOp>(operation, dstType,
1165 operands);
1166 return success();
1167 }
1168
1169 Location loc = operation.getLoc();
1170 Value extended;
1171 if (isUnsignedIntegerOrVector(op2Type)) {
1172 extended = rewriter.template create<LLVM::ZExtOp>(loc, dstType,
1173 operation.operand2());
1174 } else {
1175 extended = rewriter.template create<LLVM::SExtOp>(loc, dstType,
1176 operation.operand2());
1177 }
1178 Value result = rewriter.template create<LLVMOp>(
1179 loc, dstType, operation.operand1(), extended);
1180 rewriter.replaceOp(operation, result);
1181 return success();
1182 }
1183 };
1184
1185 class TanPattern : public SPIRVToLLVMConversion<spirv::GLSLTanOp> {
1186 public:
1187 using SPIRVToLLVMConversion<spirv::GLSLTanOp>::SPIRVToLLVMConversion;
1188
1189 LogicalResult
matchAndRewrite(spirv::GLSLTanOp tanOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1190 matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef<Value> operands,
1191 ConversionPatternRewriter &rewriter) const override {
1192 auto dstType = typeConverter.convertType(tanOp.getType());
1193 if (!dstType)
1194 return failure();
1195
1196 Location loc = tanOp.getLoc();
1197 Value sin = rewriter.create<LLVM::SinOp>(loc, dstType, tanOp.operand());
1198 Value cos = rewriter.create<LLVM::CosOp>(loc, dstType, tanOp.operand());
1199 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanOp, dstType, sin, cos);
1200 return success();
1201 }
1202 };
1203
1204 /// Convert `spv.Tanh` to
1205 ///
1206 /// exp(2x) - 1
1207 /// -----------
1208 /// exp(2x) + 1
1209 ///
1210 class TanhPattern : public SPIRVToLLVMConversion<spirv::GLSLTanhOp> {
1211 public:
1212 using SPIRVToLLVMConversion<spirv::GLSLTanhOp>::SPIRVToLLVMConversion;
1213
1214 LogicalResult
matchAndRewrite(spirv::GLSLTanhOp tanhOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1215 matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef<Value> operands,
1216 ConversionPatternRewriter &rewriter) const override {
1217 auto srcType = tanhOp.getType();
1218 auto dstType = typeConverter.convertType(srcType);
1219 if (!dstType)
1220 return failure();
1221
1222 Location loc = tanhOp.getLoc();
1223 Value two = createFPConstant(loc, srcType, dstType, rewriter, 2.0);
1224 Value multiplied =
1225 rewriter.create<LLVM::FMulOp>(loc, dstType, two, tanhOp.operand());
1226 Value exponential = rewriter.create<LLVM::ExpOp>(loc, dstType, multiplied);
1227 Value one = createFPConstant(loc, srcType, dstType, rewriter, 1.0);
1228 Value numerator =
1229 rewriter.create<LLVM::FSubOp>(loc, dstType, exponential, one);
1230 Value denominator =
1231 rewriter.create<LLVM::FAddOp>(loc, dstType, exponential, one);
1232 rewriter.replaceOpWithNewOp<LLVM::FDivOp>(tanhOp, dstType, numerator,
1233 denominator);
1234 return success();
1235 }
1236 };
1237
1238 class VariablePattern : public SPIRVToLLVMConversion<spirv::VariableOp> {
1239 public:
1240 using SPIRVToLLVMConversion<spirv::VariableOp>::SPIRVToLLVMConversion;
1241
1242 LogicalResult
matchAndRewrite(spirv::VariableOp varOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1243 matchAndRewrite(spirv::VariableOp varOp, ArrayRef<Value> operands,
1244 ConversionPatternRewriter &rewriter) const override {
1245 auto srcType = varOp.getType();
1246 // Initialization is supported for scalars and vectors only.
1247 auto pointerTo = srcType.cast<spirv::PointerType>().getPointeeType();
1248 auto init = varOp.initializer();
1249 if (init && !pointerTo.isIntOrFloat() && !pointerTo.isa<VectorType>())
1250 return failure();
1251
1252 auto dstType = typeConverter.convertType(srcType);
1253 if (!dstType)
1254 return failure();
1255
1256 Location loc = varOp.getLoc();
1257 Value size = createI32ConstantOf(loc, rewriter, 1);
1258 if (!init) {
1259 rewriter.replaceOpWithNewOp<LLVM::AllocaOp>(varOp, dstType, size);
1260 return success();
1261 }
1262 Value allocated = rewriter.create<LLVM::AllocaOp>(loc, dstType, size);
1263 rewriter.create<LLVM::StoreOp>(loc, init, allocated);
1264 rewriter.replaceOp(varOp, allocated);
1265 return success();
1266 }
1267 };
1268
1269 //===----------------------------------------------------------------------===//
1270 // FuncOp conversion
1271 //===----------------------------------------------------------------------===//
1272
1273 class FuncConversionPattern : public SPIRVToLLVMConversion<spirv::FuncOp> {
1274 public:
1275 using SPIRVToLLVMConversion<spirv::FuncOp>::SPIRVToLLVMConversion;
1276
1277 LogicalResult
matchAndRewrite(spirv::FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1278 matchAndRewrite(spirv::FuncOp funcOp, ArrayRef<Value> operands,
1279 ConversionPatternRewriter &rewriter) const override {
1280
1281 // Convert function signature. At the moment LLVMType converter is enough
1282 // for currently supported types.
1283 auto funcType = funcOp.getType();
1284 TypeConverter::SignatureConversion signatureConverter(
1285 funcType.getNumInputs());
1286 auto llvmType = typeConverter.convertFunctionSignature(
1287 funcOp.getType(), /*isVariadic=*/false, signatureConverter);
1288 if (!llvmType)
1289 return failure();
1290
1291 // Create a new `LLVMFuncOp`
1292 Location loc = funcOp.getLoc();
1293 StringRef name = funcOp.getName();
1294 auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(loc, name, llvmType);
1295
1296 // Convert SPIR-V Function Control to equivalent LLVM function attribute
1297 MLIRContext *context = funcOp.getContext();
1298 switch (funcOp.function_control()) {
1299 #define DISPATCH(functionControl, llvmAttr) \
1300 case functionControl: \
1301 newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \
1302 break;
1303
1304 DISPATCH(spirv::FunctionControl::Inline,
1305 StringAttr::get("alwaysinline", context));
1306 DISPATCH(spirv::FunctionControl::DontInline,
1307 StringAttr::get("noinline", context));
1308 DISPATCH(spirv::FunctionControl::Pure,
1309 StringAttr::get("readonly", context));
1310 DISPATCH(spirv::FunctionControl::Const,
1311 StringAttr::get("readnone", context));
1312
1313 #undef DISPATCH
1314
1315 // Default: if `spirv::FunctionControl::None`, then no attributes are
1316 // needed.
1317 default:
1318 break;
1319 }
1320
1321 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
1322 newFuncOp.end());
1323 if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
1324 &signatureConverter))) {
1325 return failure();
1326 }
1327 rewriter.eraseOp(funcOp);
1328 return success();
1329 }
1330 };
1331
1332 //===----------------------------------------------------------------------===//
1333 // ModuleOp conversion
1334 //===----------------------------------------------------------------------===//
1335
1336 class ModuleConversionPattern : public SPIRVToLLVMConversion<spirv::ModuleOp> {
1337 public:
1338 using SPIRVToLLVMConversion<spirv::ModuleOp>::SPIRVToLLVMConversion;
1339
1340 LogicalResult
matchAndRewrite(spirv::ModuleOp spvModuleOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1341 matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef<Value> operands,
1342 ConversionPatternRewriter &rewriter) const override {
1343
1344 auto newModuleOp =
1345 rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
1346 rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
1347
1348 // Remove the terminator block that was automatically added by builder
1349 rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
1350 rewriter.eraseOp(spvModuleOp);
1351 return success();
1352 }
1353 };
1354
1355 class ModuleEndConversionPattern
1356 : public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
1357 public:
1358 using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
1359
1360 LogicalResult
matchAndRewrite(spirv::ModuleEndOp moduleEndOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const1361 matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
1362 ConversionPatternRewriter &rewriter) const override {
1363
1364 rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
1365 return success();
1366 }
1367 };
1368
1369 } // namespace
1370
1371 //===----------------------------------------------------------------------===//
1372 // Pattern population
1373 //===----------------------------------------------------------------------===//
1374
populateSPIRVToLLVMTypeConversion(LLVMTypeConverter & typeConverter)1375 void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
1376 typeConverter.addConversion([&](spirv::ArrayType type) {
1377 return convertArrayType(type, typeConverter);
1378 });
1379 typeConverter.addConversion([&](spirv::PointerType type) {
1380 return convertPointerType(type, typeConverter);
1381 });
1382 typeConverter.addConversion([&](spirv::RuntimeArrayType type) {
1383 return convertRuntimeArrayType(type, typeConverter);
1384 });
1385 typeConverter.addConversion([&](spirv::StructType type) {
1386 return convertStructType(type, typeConverter);
1387 });
1388 }
1389
populateSPIRVToLLVMConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1390 void mlir::populateSPIRVToLLVMConversionPatterns(
1391 MLIRContext *context, LLVMTypeConverter &typeConverter,
1392 OwningRewritePatternList &patterns) {
1393 patterns.insert<
1394 // Arithmetic ops
1395 DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
1396 DirectConversionPattern<spirv::IMulOp, LLVM::MulOp>,
1397 DirectConversionPattern<spirv::ISubOp, LLVM::SubOp>,
1398 DirectConversionPattern<spirv::FAddOp, LLVM::FAddOp>,
1399 DirectConversionPattern<spirv::FDivOp, LLVM::FDivOp>,
1400 DirectConversionPattern<spirv::FMulOp, LLVM::FMulOp>,
1401 DirectConversionPattern<spirv::FNegateOp, LLVM::FNegOp>,
1402 DirectConversionPattern<spirv::FRemOp, LLVM::FRemOp>,
1403 DirectConversionPattern<spirv::FSubOp, LLVM::FSubOp>,
1404 DirectConversionPattern<spirv::SDivOp, LLVM::SDivOp>,
1405 DirectConversionPattern<spirv::SRemOp, LLVM::SRemOp>,
1406 DirectConversionPattern<spirv::UDivOp, LLVM::UDivOp>,
1407 DirectConversionPattern<spirv::UModOp, LLVM::URemOp>,
1408
1409 // Bitwise ops
1410 BitFieldInsertPattern, BitFieldUExtractPattern, BitFieldSExtractPattern,
1411 DirectConversionPattern<spirv::BitCountOp, LLVM::CtPopOp>,
1412 DirectConversionPattern<spirv::BitReverseOp, LLVM::BitReverseOp>,
1413 DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
1414 DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
1415 DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
1416 NotPattern<spirv::NotOp>,
1417
1418 // Cast ops
1419 DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
1420 DirectConversionPattern<spirv::ConvertFToSOp, LLVM::FPToSIOp>,
1421 DirectConversionPattern<spirv::ConvertFToUOp, LLVM::FPToUIOp>,
1422 DirectConversionPattern<spirv::ConvertSToFOp, LLVM::SIToFPOp>,
1423 DirectConversionPattern<spirv::ConvertUToFOp, LLVM::UIToFPOp>,
1424 IndirectCastPattern<spirv::FConvertOp, LLVM::FPExtOp, LLVM::FPTruncOp>,
1425 IndirectCastPattern<spirv::SConvertOp, LLVM::SExtOp, LLVM::TruncOp>,
1426 IndirectCastPattern<spirv::UConvertOp, LLVM::ZExtOp, LLVM::TruncOp>,
1427
1428 // Comparison ops
1429 IComparePattern<spirv::IEqualOp, LLVM::ICmpPredicate::eq>,
1430 IComparePattern<spirv::INotEqualOp, LLVM::ICmpPredicate::ne>,
1431 FComparePattern<spirv::FOrdEqualOp, LLVM::FCmpPredicate::oeq>,
1432 FComparePattern<spirv::FOrdGreaterThanOp, LLVM::FCmpPredicate::ogt>,
1433 FComparePattern<spirv::FOrdGreaterThanEqualOp, LLVM::FCmpPredicate::oge>,
1434 FComparePattern<spirv::FOrdLessThanEqualOp, LLVM::FCmpPredicate::ole>,
1435 FComparePattern<spirv::FOrdLessThanOp, LLVM::FCmpPredicate::olt>,
1436 FComparePattern<spirv::FOrdNotEqualOp, LLVM::FCmpPredicate::one>,
1437 FComparePattern<spirv::FUnordEqualOp, LLVM::FCmpPredicate::ueq>,
1438 FComparePattern<spirv::FUnordGreaterThanOp, LLVM::FCmpPredicate::ugt>,
1439 FComparePattern<spirv::FUnordGreaterThanEqualOp,
1440 LLVM::FCmpPredicate::uge>,
1441 FComparePattern<spirv::FUnordLessThanEqualOp, LLVM::FCmpPredicate::ule>,
1442 FComparePattern<spirv::FUnordLessThanOp, LLVM::FCmpPredicate::ult>,
1443 FComparePattern<spirv::FUnordNotEqualOp, LLVM::FCmpPredicate::une>,
1444 IComparePattern<spirv::SGreaterThanOp, LLVM::ICmpPredicate::sgt>,
1445 IComparePattern<spirv::SGreaterThanEqualOp, LLVM::ICmpPredicate::sge>,
1446 IComparePattern<spirv::SLessThanEqualOp, LLVM::ICmpPredicate::sle>,
1447 IComparePattern<spirv::SLessThanOp, LLVM::ICmpPredicate::slt>,
1448 IComparePattern<spirv::UGreaterThanOp, LLVM::ICmpPredicate::ugt>,
1449 IComparePattern<spirv::UGreaterThanEqualOp, LLVM::ICmpPredicate::uge>,
1450 IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
1451 IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
1452
1453 // Constant op
1454 ConstantScalarAndVectorPattern,
1455
1456 // Control Flow ops
1457 BranchConversionPattern, BranchConditionalConversionPattern,
1458 FunctionCallPattern, LoopPattern, SelectionPattern,
1459 ErasePattern<spirv::MergeOp>,
1460
1461 // Entry points and execution mode are handled separately.
1462 ErasePattern<spirv::EntryPointOp>, ExecutionModePattern,
1463
1464 // GLSL extended instruction set ops
1465 DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
1466 DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
1467 DirectConversionPattern<spirv::GLSLExpOp, LLVM::ExpOp>,
1468 DirectConversionPattern<spirv::GLSLFAbsOp, LLVM::FAbsOp>,
1469 DirectConversionPattern<spirv::GLSLFloorOp, LLVM::FFloorOp>,
1470 DirectConversionPattern<spirv::GLSLFMaxOp, LLVM::MaxNumOp>,
1471 DirectConversionPattern<spirv::GLSLFMinOp, LLVM::MinNumOp>,
1472 DirectConversionPattern<spirv::GLSLLogOp, LLVM::LogOp>,
1473 DirectConversionPattern<spirv::GLSLSinOp, LLVM::SinOp>,
1474 DirectConversionPattern<spirv::GLSLSMaxOp, LLVM::SMaxOp>,
1475 DirectConversionPattern<spirv::GLSLSMinOp, LLVM::SMinOp>,
1476 DirectConversionPattern<spirv::GLSLSqrtOp, LLVM::SqrtOp>,
1477 InverseSqrtPattern, TanPattern, TanhPattern,
1478
1479 // Logical ops
1480 DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
1481 DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
1482 IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
1483 IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
1484 NotPattern<spirv::LogicalNotOp>,
1485
1486 // Memory ops
1487 AccessChainPattern, AddressOfPattern, GlobalVariablePattern,
1488 LoadStorePattern<spirv::LoadOp>, LoadStorePattern<spirv::StoreOp>,
1489 VariablePattern,
1490
1491 // Miscellaneous ops
1492 CompositeExtractPattern, CompositeInsertPattern,
1493 DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
1494 DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
1495
1496 // Shift ops
1497 ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
1498 ShiftPattern<spirv::ShiftRightLogicalOp, LLVM::LShrOp>,
1499 ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
1500
1501 // Return ops
1502 ReturnPattern, ReturnValuePattern>(context, typeConverter);
1503 }
1504
populateSPIRVToLLVMFunctionConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1505 void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
1506 MLIRContext *context, LLVMTypeConverter &typeConverter,
1507 OwningRewritePatternList &patterns) {
1508 patterns.insert<FuncConversionPattern>(context, typeConverter);
1509 }
1510
populateSPIRVToLLVMModuleConversionPatterns(MLIRContext * context,LLVMTypeConverter & typeConverter,OwningRewritePatternList & patterns)1511 void mlir::populateSPIRVToLLVMModuleConversionPatterns(
1512 MLIRContext *context, LLVMTypeConverter &typeConverter,
1513 OwningRewritePatternList &patterns) {
1514 patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
1515 context, typeConverter);
1516 }
1517
1518 //===----------------------------------------------------------------------===//
1519 // Pre-conversion hooks
1520 //===----------------------------------------------------------------------===//
1521
1522 /// Hook for descriptor set and binding number encoding.
1523 static constexpr StringRef kBinding = "binding";
1524 static constexpr StringRef kDescriptorSet = "descriptor_set";
encodeBindAttribute(ModuleOp module)1525 void mlir::encodeBindAttribute(ModuleOp module) {
1526 auto spvModules = module.getOps<spirv::ModuleOp>();
1527 for (auto spvModule : spvModules) {
1528 spvModule.walk([&](spirv::GlobalVariableOp op) {
1529 IntegerAttr descriptorSet =
1530 op->getAttrOfType<IntegerAttr>(kDescriptorSet);
1531 IntegerAttr binding = op->getAttrOfType<IntegerAttr>(kBinding);
1532 // For every global variable in the module, get the ones with descriptor
1533 // set and binding numbers.
1534 if (descriptorSet && binding) {
1535 // Encode these numbers into the variable's symbolic name. If the
1536 // SPIR-V module has a name, add it at the beginning.
1537 auto moduleAndName = spvModule.getName().hasValue()
1538 ? spvModule.getName().getValue().str() + "_" +
1539 op.sym_name().str()
1540 : op.sym_name().str();
1541 std::string name =
1542 llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
1543 std::to_string(descriptorSet.getInt()),
1544 std::to_string(binding.getInt()));
1545
1546 // Replace all symbol uses and set the new symbol name. Finally, remove
1547 // descriptor set and binding attributes.
1548 if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
1549 op.emitError("unable to replace all symbol uses for ") << name;
1550 SymbolTable::setSymbolName(op, name);
1551 op.removeAttr(kDescriptorSet);
1552 op.removeAttr(kBinding);
1553 }
1554 });
1555 }
1556 }
1557