1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <functional>
22 
23 #define DEBUG_TYPE "mlir-spirv-conversion"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Checks that `candidates` extension requirements are possible to be satisfied
32 /// with the given `targetEnv`.
33 ///
34 ///  `candidates` is a vector of vector for extension requirements following
35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
36 /// convention.
37 template <typename LabelT>
checkExtensionRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates)38 static LogicalResult checkExtensionRequirements(
39     LabelT label, const spirv::TargetEnv &targetEnv,
40     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
41   for (const auto &ors : candidates) {
42     if (targetEnv.allows(ors))
43       continue;
44 
45     SmallVector<StringRef, 4> extStrings;
46     for (spirv::Extension ext : ors)
47       extStrings.push_back(spirv::stringifyExtension(ext));
48 
49     LLVM_DEBUG(llvm::dbgs()
50                << label << " illegal: requires at least one extension in ["
51                << llvm::join(extStrings, ", ")
52                << "] but none allowed in target environment\n");
53     return failure();
54   }
55   return success();
56 }
57 
58 /// Checks that `candidates`capability requirements are possible to be satisfied
59 /// with the given `isAllowedFn`.
60 ///
61 ///  `candidates` is a vector of vector for capability requirements following
62 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
63 /// convention.
64 template <typename LabelT>
checkCapabilityRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates)65 static LogicalResult checkCapabilityRequirements(
66     LabelT label, const spirv::TargetEnv &targetEnv,
67     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
68   for (const auto &ors : candidates) {
69     if (targetEnv.allows(ors))
70       continue;
71 
72     SmallVector<StringRef, 4> capStrings;
73     for (spirv::Capability cap : ors)
74       capStrings.push_back(spirv::stringifyCapability(cap));
75 
76     LLVM_DEBUG(llvm::dbgs()
77                << label << " illegal: requires at least one capability in ["
78                << llvm::join(capStrings, ", ")
79                << "] but none allowed in target environment\n");
80     return failure();
81   }
82   return success();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Type Conversion
87 //===----------------------------------------------------------------------===//
88 
getIndexType(MLIRContext * context)89 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
90   // Convert to 32-bit integers for now. Might need a way to control this in
91   // future.
92   // TODO: It is probably better to make it 64-bit integers. To
93   // this some support is needed in SPIR-V dialect for Conversion
94   // instructions. The Vulkan spec requires the builtins like
95   // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
96   // SExtended to 64-bit for index computations.
97   return IntegerType::get(context, 32);
98 }
99 
/// Mapping between SPIR-V storage classes and memref memory spaces, written as
/// an X-macro: `STORAGE_SPACE_MAP_LIST(ENTRY)` expands to one
/// `ENTRY(storageClass, memorySpace)` invocation per mapping.
///
/// Note: memref does not define semantics for each memory space; the meaning
/// depends on the context where it is used. There is no particular reason
/// behind the number assignments; they try to follow NVVM conventions and
/// largely give common storage classes smaller numbers. The hope is to use a
/// symbolic memory space representation once memref supports it.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(ENTRY)                                          \
  ENTRY(spirv::StorageClass::Generic, 1)                                       \
  ENTRY(spirv::StorageClass::StorageBuffer, 0)                                 \
  ENTRY(spirv::StorageClass::Workgroup, 3)                                     \
  ENTRY(spirv::StorageClass::Uniform, 4)                                       \
  ENTRY(spirv::StorageClass::Private, 5)                                       \
  ENTRY(spirv::StorageClass::Function, 6)                                      \
  ENTRY(spirv::StorageClass::PushConstant, 7)                                  \
  ENTRY(spirv::StorageClass::UniformConstant, 8)                               \
  ENTRY(spirv::StorageClass::Input, 9)                                         \
  ENTRY(spirv::StorageClass::Output, 10)                                       \
  ENTRY(spirv::StorageClass::CrossWorkgroup, 11)                               \
  ENTRY(spirv::StorageClass::AtomicCounter, 12)                                \
  ENTRY(spirv::StorageClass::Image, 13)                                        \
  ENTRY(spirv::StorageClass::CallableDataNV, 14)                               \
  ENTRY(spirv::StorageClass::IncomingCallableDataNV, 15)                       \
  ENTRY(spirv::StorageClass::RayPayloadNV, 16)                                 \
  ENTRY(spirv::StorageClass::HitAttributeNV, 17)                               \
  ENTRY(spirv::StorageClass::IncomingRayPayloadNV, 18)                         \
  ENTRY(spirv::StorageClass::ShaderRecordBufferNV, 19)                         \
  ENTRY(spirv::StorageClass::PhysicalStorageBuffer, 20)
130 
131 unsigned
getMemorySpaceForStorageClass(spirv::StorageClass storage)132 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
133 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
134   case storage:                                                                \
135     return space;
136 
137   switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
138 #undef STORAGE_SPACE_MAP_FN
139   llvm_unreachable("unhandled storage class!");
140 }
141 
142 Optional<spirv::StorageClass>
getStorageClassForMemorySpace(unsigned space)143 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
144 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
145   case space:                                                                  \
146     return storage;
147 
148   switch (space) {
149     STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
150   default:
151     return llvm::None;
152   }
153 #undef STORAGE_SPACE_MAP_FN
154 }
155 
156 #undef STORAGE_SPACE_MAP_LIST
157 
158 // TODO: This is a utility function that should probably be
159 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
getTypeNumBytes(Type t)160 static Optional<int64_t> getTypeNumBytes(Type t) {
161   if (t.isa<spirv::ScalarType>()) {
162     auto bitWidth = t.getIntOrFloatBitWidth();
163     // According to the SPIR-V spec:
164     // "There is no physical size or bit pattern defined for values with boolean
165     // type. If they are stored (in conjunction with OpVariable), they can only
166     // be used with logical addressing operations, not physical, and only with
167     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
168     // Private, Function, Input, and Output."
169     if (bitWidth == 1) {
170       return llvm::None;
171     }
172     return bitWidth / 8;
173   }
174   if (auto vecType = t.dyn_cast<VectorType>()) {
175     auto elementSize = getTypeNumBytes(vecType.getElementType());
176     if (!elementSize)
177       return llvm::None;
178     return vecType.getNumElements() * *elementSize;
179   }
180   if (auto memRefType = t.dyn_cast<MemRefType>()) {
181     // TODO: Layout should also be controlled by the ABI attributes. For now
182     // using the layout from MemRef.
183     int64_t offset;
184     SmallVector<int64_t, 4> strides;
185     if (!memRefType.hasStaticShape() ||
186         failed(getStridesAndOffset(memRefType, strides, offset))) {
187       return llvm::None;
188     }
189     // To get the size of the memref object in memory, the total size is the
190     // max(stride * dimension-size) computed for all dimensions times the size
191     // of the element.
192     auto elementSize = getTypeNumBytes(memRefType.getElementType());
193     if (!elementSize) {
194       return llvm::None;
195     }
196     if (memRefType.getRank() == 0) {
197       return elementSize;
198     }
199     auto dims = memRefType.getShape();
200     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
201         offset == MemRefType::getDynamicStrideOrOffset() ||
202         llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
203       return llvm::None;
204     }
205     int64_t memrefSize = -1;
206     for (auto shape : enumerate(dims)) {
207       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
208     }
209     return (offset + memrefSize) * elementSize.getValue();
210   } else if (auto tensorType = t.dyn_cast<TensorType>()) {
211     if (!tensorType.hasStaticShape()) {
212       return llvm::None;
213     }
214     auto elementSize = getTypeNumBytes(tensorType.getElementType());
215     if (!elementSize) {
216       return llvm::None;
217     }
218     int64_t size = elementSize.getValue();
219     for (auto shape : tensorType.getShape()) {
220       size *= shape;
221     }
222     return size;
223   }
224   // TODO: Add size computation for other types.
225   return llvm::None;
226 }
227 
getConvertedTypeNumBytes(Type t)228 Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
229   return getTypeNumBytes(t);
230 }
231 
232 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
233 static Optional<Type>
convertScalarType(const spirv::TargetEnv & targetEnv,spirv::ScalarType type,Optional<spirv::StorageClass> storageClass={})234 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
235                   Optional<spirv::StorageClass> storageClass = {}) {
236   // Get extension and capability requirements for the given type.
237   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
238   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
239   type.getExtensions(extensions, storageClass);
240   type.getCapabilities(capabilities, storageClass);
241 
242   // If all requirements are met, then we can accept this type as-is.
243   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
244       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
245     return type;
246 
247   // Otherwise we need to adjust the type, which really means adjusting the
248   // bitwidth given this is a scalar type.
249   // TODO: We are unconditionally converting the bitwidth here,
250   // this might be okay for non-interface types (i.e., types used in
251   // Private/Function storage classes), but not for interface types (i.e.,
252   // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
253   // This is because the later actually affects the ABI contract with the
254   // runtime. So we may want to expose a control on SPIRVTypeConverter to fail
255   // conversion if we cannot change there.
256 
257   if (auto floatType = type.dyn_cast<FloatType>()) {
258     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
259     return Builder(targetEnv.getContext()).getF32Type();
260   }
261 
262   auto intType = type.cast<IntegerType>();
263   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
264   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
265                           intType.getSignedness());
266 }
267 
268 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
269 static Optional<Type>
convertVectorType(const spirv::TargetEnv & targetEnv,VectorType type,Optional<spirv::StorageClass> storageClass={})270 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
271                   Optional<spirv::StorageClass> storageClass = {}) {
272   if (!spirv::CompositeType::isValid(type)) {
273     // TODO: One-element vector types can be translated into scalar
274     // types. Vector types with more than four elements can be translated into
275     // array types.
276     LLVM_DEBUG(llvm::dbgs()
277                << type << " illegal: 1- and > 4-element unimplemented\n");
278     return llvm::None;
279   }
280 
281   // Get extension and capability requirements for the given type.
282   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
283   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
284   type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
285   type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
286 
287   // If all requirements are met, then we can accept this type as-is.
288   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
289       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
290     return type;
291 
292   auto elementType = convertScalarType(
293       targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
294   if (elementType)
295     return VectorType::get(type.getShape(), *elementType);
296   return llvm::None;
297 }
298 
299 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
300 ///
301 /// Note that this is mainly for lowering constant tensors.In SPIR-V one can
302 /// create composite constants with OpConstantComposite to embed relative large
303 /// constant values and use OpCompositeExtract and OpCompositeInsert to
304 /// manipulate, like what we do for vectors.
convertTensorType(const spirv::TargetEnv & targetEnv,TensorType type)305 static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
306                                         TensorType type) {
307   // TODO: Handle dynamic shapes.
308   if (!type.hasStaticShape()) {
309     LLVM_DEBUG(llvm::dbgs()
310                << type << " illegal: dynamic shape unimplemented\n");
311     return llvm::None;
312   }
313 
314   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
315   if (!scalarType) {
316     LLVM_DEBUG(llvm::dbgs()
317                << type << " illegal: cannot convert non-scalar element type\n");
318     return llvm::None;
319   }
320 
321   Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
322   Optional<int64_t> tensorSize = getTypeNumBytes(type);
323   if (!scalarSize || !tensorSize) {
324     LLVM_DEBUG(llvm::dbgs()
325                << type << " illegal: cannot deduce element count\n");
326     return llvm::None;
327   }
328 
329   auto arrayElemCount = *tensorSize / *scalarSize;
330   auto arrayElemType = convertScalarType(targetEnv, scalarType);
331   if (!arrayElemType)
332     return llvm::None;
333   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
334   if (!arrayElemSize) {
335     LLVM_DEBUG(llvm::dbgs()
336                << type << " illegal: cannot deduce converted element size\n");
337     return llvm::None;
338   }
339 
340   return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
341 }
342 
convertMemrefType(const spirv::TargetEnv & targetEnv,MemRefType type)343 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
344                                         MemRefType type) {
345   Optional<spirv::StorageClass> storageClass =
346       SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
347   if (!storageClass) {
348     LLVM_DEBUG(llvm::dbgs()
349                << type << " illegal: cannot convert memory space\n");
350     return llvm::None;
351   }
352 
353   Optional<Type> arrayElemType;
354   Type elementType = type.getElementType();
355   if (auto vecType = elementType.dyn_cast<VectorType>()) {
356     arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
357   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
358     arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
359   } else {
360     LLVM_DEBUG(
361         llvm::dbgs()
362         << type
363         << " unhandled: can only convert scalar or vector element type\n");
364     return llvm::None;
365   }
366   if (!arrayElemType)
367     return llvm::None;
368 
369   Optional<int64_t> elementSize = getTypeNumBytes(elementType);
370   if (!elementSize) {
371     LLVM_DEBUG(llvm::dbgs()
372                << type << " illegal: cannot deduce element size\n");
373     return llvm::None;
374   }
375 
376   if (!type.hasStaticShape()) {
377     auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
378     // Wrap in a struct to satisfy Vulkan interface requirements.
379     auto structType = spirv::StructType::get(arrayType, 0);
380     return spirv::PointerType::get(structType, *storageClass);
381   }
382 
383   Optional<int64_t> memrefSize = getTypeNumBytes(type);
384   if (!memrefSize) {
385     LLVM_DEBUG(llvm::dbgs()
386                << type << " illegal: cannot deduce element count\n");
387     return llvm::None;
388   }
389 
390   auto arrayElemCount = *memrefSize / *elementSize;
391 
392   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
393   if (!arrayElemSize) {
394     LLVM_DEBUG(llvm::dbgs()
395                << type << " illegal: cannot deduce converted element size\n");
396     return llvm::None;
397   }
398 
399   auto arrayType =
400       spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
401 
402   // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
403   // workgroup storage class do not need the struct to be laid out explicitly.
404   auto structType = *storageClass == spirv::StorageClass::Workgroup
405                         ? spirv::StructType::get(arrayType)
406                         : spirv::StructType::get(arrayType, 0);
407   return spirv::PointerType::get(structType, *storageClass);
408 }
409 
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)410 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
411     : targetEnv(targetAttr) {
412   // Add conversions. The order matters here: later ones will be tried earlier.
413 
414   // All other cases failed. Then we cannot convert this type.
415   addConversion([](Type type) { return llvm::None; });
416 
417   // Allow all SPIR-V dialect specific types. This assumes all builtin types
418   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
419   // were tried before.
420   //
421   // TODO: this assumes that the SPIR-V types are valid to use in
422   // the given target environment, which should be the case if the whole
423   // pipeline is driven by the same target environment. Still, we probably still
424   // want to validate and convert to be safe.
425   addConversion([](spirv::SPIRVType type) { return type; });
426 
427   addConversion([](IndexType indexType) {
428     return SPIRVTypeConverter::getIndexType(indexType.getContext());
429   });
430 
431   addConversion([this](IntegerType intType) -> Optional<Type> {
432     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
433       return convertScalarType(targetEnv, scalarType);
434     return llvm::None;
435   });
436 
437   addConversion([this](FloatType floatType) -> Optional<Type> {
438     if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
439       return convertScalarType(targetEnv, scalarType);
440     return llvm::None;
441   });
442 
443   addConversion([this](VectorType vectorType) {
444     return convertVectorType(targetEnv, vectorType);
445   });
446 
447   addConversion([this](TensorType tensorType) {
448     return convertTensorType(targetEnv, tensorType);
449   });
450 
451   addConversion([this](MemRefType memRefType) {
452     return convertMemrefType(targetEnv, memRefType);
453   });
454 }
455 
456 //===----------------------------------------------------------------------===//
457 // FuncOp Conversion Patterns
458 //===----------------------------------------------------------------------===//
459 
460 namespace {
461 /// A pattern for rewriting function signature to convert arguments of functions
462 /// to be of valid SPIR-V types.
463 class FuncOpConversion final : public OpConversionPattern<FuncOp> {
464 public:
465   using OpConversionPattern<FuncOp>::OpConversionPattern;
466 
467   LogicalResult
468   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
469                   ConversionPatternRewriter &rewriter) const override;
470 };
471 } // namespace
472 
473 LogicalResult
matchAndRewrite(FuncOp funcOp,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const474 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
475                                   ConversionPatternRewriter &rewriter) const {
476   auto fnType = funcOp.getType();
477   if (fnType.getNumResults() > 1)
478     return failure();
479 
480   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
481   for (auto argType : enumerate(fnType.getInputs())) {
482     auto convertedType = getTypeConverter()->convertType(argType.value());
483     if (!convertedType)
484       return failure();
485     signatureConverter.addInputs(argType.index(), convertedType);
486   }
487 
488   Type resultType;
489   if (fnType.getNumResults() == 1)
490     resultType = getTypeConverter()->convertType(fnType.getResult(0));
491 
492   // Create the converted spv.func op.
493   auto newFuncOp = rewriter.create<spirv::FuncOp>(
494       funcOp.getLoc(), funcOp.getName(),
495       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
496                                resultType ? TypeRange(resultType)
497                                           : TypeRange()));
498 
499   // Copy over all attributes other than the function name and type.
500   for (const auto &namedAttr : funcOp.getAttrs()) {
501     if (namedAttr.first != impl::getTypeAttrName() &&
502         namedAttr.first != SymbolTable::getSymbolAttrName())
503       newFuncOp->setAttr(namedAttr.first, namedAttr.second);
504   }
505 
506   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
507                               newFuncOp.end());
508   if (failed(rewriter.convertRegionTypes(
509           &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
510     return failure();
511   rewriter.eraseOp(funcOp);
512   return success();
513 }
514 
populateBuiltinFuncToSPIRVPatterns(MLIRContext * context,SPIRVTypeConverter & typeConverter,OwningRewritePatternList & patterns)515 void mlir::populateBuiltinFuncToSPIRVPatterns(
516     MLIRContext *context, SPIRVTypeConverter &typeConverter,
517     OwningRewritePatternList &patterns) {
518   patterns.insert<FuncOpConversion>(typeConverter, context);
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // Builtin Variables
523 //===----------------------------------------------------------------------===//
524 
getBuiltinVariable(Block & body,spirv::BuiltIn builtin)525 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
526                                                   spirv::BuiltIn builtin) {
527   // Look through all global variables in the given `body` block and check if
528   // there is a spv.globalVariable that has the same `builtin` attribute.
529   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
530     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
531             spirv::SPIRVDialect::getAttributeName(
532                 spirv::Decoration::BuiltIn))) {
533       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
534       if (varBuiltIn && varBuiltIn.getValue() == builtin) {
535         return varOp;
536       }
537     }
538   }
539   return nullptr;
540 }
541 
542 /// Gets name of global variable for a builtin.
getBuiltinVarName(spirv::BuiltIn builtin)543 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
544   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
545 }
546 
547 /// Gets or inserts a global variable for a builtin within `body` block.
548 static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block & body,Location loc,spirv::BuiltIn builtin,OpBuilder & builder)549 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
550                            OpBuilder &builder) {
551   if (auto varOp = getBuiltinVariable(body, builtin))
552     return varOp;
553 
554   OpBuilder::InsertionGuard guard(builder);
555   builder.setInsertionPointToStart(&body);
556 
557   spirv::GlobalVariableOp newVarOp;
558   switch (builtin) {
559   case spirv::BuiltIn::NumWorkgroups:
560   case spirv::BuiltIn::WorkgroupSize:
561   case spirv::BuiltIn::WorkgroupId:
562   case spirv::BuiltIn::LocalInvocationId:
563   case spirv::BuiltIn::GlobalInvocationId: {
564     auto ptrType = spirv::PointerType::get(
565         VectorType::get({3}, builder.getIntegerType(32)),
566         spirv::StorageClass::Input);
567     std::string name = getBuiltinVarName(builtin);
568     newVarOp =
569         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
570     break;
571   }
572   case spirv::BuiltIn::SubgroupId:
573   case spirv::BuiltIn::NumSubgroups:
574   case spirv::BuiltIn::SubgroupSize: {
575     auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
576                                            spirv::StorageClass::Input);
577     std::string name = getBuiltinVarName(builtin);
578     newVarOp =
579         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
580     break;
581   }
582   default:
583     emitError(loc, "unimplemented builtin variable generation for ")
584         << stringifyBuiltIn(builtin);
585   }
586   return newVarOp;
587 }
588 
getBuiltinVariableValue(Operation * op,spirv::BuiltIn builtin,OpBuilder & builder)589 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
590                                            spirv::BuiltIn builtin,
591                                            OpBuilder &builder) {
592   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
593   if (!parent) {
594     op->emitError("expected operation to be within a module-like op");
595     return nullptr;
596   }
597 
598   spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
599       *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
600   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
601   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // Index calculation
606 //===----------------------------------------------------------------------===//
607 
getElementPtr(SPIRVTypeConverter & typeConverter,MemRefType baseType,Value basePtr,ValueRange indices,Location loc,OpBuilder & builder)608 spirv::AccessChainOp mlir::spirv::getElementPtr(
609     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
610     ValueRange indices, Location loc, OpBuilder &builder) {
611   // Get base and offset of the MemRefType and verify they are static.
612 
613   int64_t offset;
614   SmallVector<int64_t, 4> strides;
615   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
616       llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
617       offset == MemRefType::getDynamicStrideOrOffset()) {
618     return nullptr;
619   }
620 
621   auto indexType = typeConverter.getIndexType(builder.getContext());
622 
623   SmallVector<Value, 2> linearizedIndices;
624   // Add a '0' at the start to index into the struct.
625   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
626   linearizedIndices.push_back(zero);
627 
628   if (baseType.getRank() == 0) {
629     linearizedIndices.push_back(zero);
630   } else {
631     // TODO: Instead of this logic, use affine.apply and add patterns for
632     // lowering affine.apply to standard ops. These will get lowered to SPIR-V
633     // ops by the DialectConversion framework.
634     Value ptrLoc = builder.create<spirv::ConstantOp>(
635         loc, indexType, IntegerAttr::get(indexType, offset));
636     assert(indices.size() == strides.size() &&
637            "must provide indices for all dimensions");
638     for (auto index : llvm::enumerate(indices)) {
639       Value strideVal = builder.create<spirv::ConstantOp>(
640           loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
641       Value update =
642           builder.create<spirv::IMulOp>(loc, strideVal, index.value());
643       ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
644     }
645     linearizedIndices.push_back(ptrLoc);
646   }
647   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
648 }
649 
650 //===----------------------------------------------------------------------===//
651 // Set ABI attributes for lowering entry functions.
652 //===----------------------------------------------------------------------===//
653 
654 LogicalResult
setABIAttrs(spirv::FuncOp funcOp,spirv::EntryPointABIAttr entryPointInfo,ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo)655 mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
656                          spirv::EntryPointABIAttr entryPointInfo,
657                          ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
658   // Set the attributes for argument and the function.
659   StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
660   for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
661     funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
662   }
663   funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
664   return success();
665 }
666 
667 //===----------------------------------------------------------------------===//
668 // SPIR-V ConversionTarget
669 //===----------------------------------------------------------------------===//
670 
671 std::unique_ptr<spirv::SPIRVConversionTarget>
get(spirv::TargetEnvAttr targetAttr)672 spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
673   std::unique_ptr<SPIRVConversionTarget> target(
674       // std::make_unique does not work here because the constructor is private.
675       new SPIRVConversionTarget(targetAttr));
676   SPIRVConversionTarget *targetPtr = target.get();
677   target->addDynamicallyLegalDialect<SPIRVDialect>(
678       // We need to capture the raw pointer here because it is stable:
679       // target will be destroyed once this function is returned.
680       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
681   return target;
682 }
683 
SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)684 spirv::SPIRVConversionTarget::SPIRVConversionTarget(
685     spirv::TargetEnvAttr targetAttr)
686     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
687 
isLegalOp(Operation * op)688 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
689   // Make sure this op is available at the given version. Ops not implementing
690   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
691   // SPIR-V versions.
692   if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
693     if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
694       LLVM_DEBUG(llvm::dbgs()
695                  << op->getName() << " illegal: requiring min version "
696                  << spirv::stringifyVersion(minVersion.getMinVersion())
697                  << "\n");
698       return false;
699     }
700   if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
701     if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
702       LLVM_DEBUG(llvm::dbgs()
703                  << op->getName() << " illegal: requiring max version "
704                  << spirv::stringifyVersion(maxVersion.getMaxVersion())
705                  << "\n");
706       return false;
707     }
708 
709   // Make sure this op's required extensions are allowed to use. Ops not
710   // implementing QueryExtensionInterface do not require extensions to be
711   // available.
712   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
713     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
714                                           extensions.getExtensions())))
715       return false;
716 
717   // Make sure this op's required extensions are allowed to use. Ops not
718   // implementing QueryCapabilityInterface do not require capabilities to be
719   // available.
720   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
721     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
722                                            capabilities.getCapabilities())))
723       return false;
724 
725   SmallVector<Type, 4> valueTypes;
726   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
727   valueTypes.append(op->result_type_begin(), op->result_type_end());
728 
729   // Special treatment for global variables, whose type requirements are
730   // conveyed by type attributes.
731   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
732     valueTypes.push_back(globalVar.type());
733 
734   // Make sure the op's operands/results use types that are allowed by the
735   // target environment.
736   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
737   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
738   for (Type valueType : valueTypes) {
739     typeExtensions.clear();
740     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
741     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
742                                           typeExtensions)))
743       return false;
744 
745     typeCapabilities.clear();
746     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
747     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
748                                            typeCapabilities)))
749       return false;
750   }
751 
752   return true;
753 }
754