1 //===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines utilities to use while converting to the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H 14 #define MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H 15 16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/ADT/SmallSet.h" 21 22 namespace mlir { 23 24 /// Type conversion from builtin types to SPIR-V types for shader interface. 25 /// 26 /// Non-32-bit scalar types require special hardware support that may not exist 27 /// on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar types 28 /// require special capabilities or extensions. Right now if a scalar type of a 29 /// certain bitwidth is not supported in the target environment, we use 32-bit 30 /// ones unconditionally. This requires the runtime to also feed in data with 31 /// a matched bitwidth and layout for interface types. The runtime can do that 32 /// by inspecting the SPIR-V module. 33 /// 34 /// For memref types, this converter additionally performs type wrapping to 35 /// satisfy shader interface requirements: shader interface types must be 36 /// pointers to structs. 37 /// 38 /// TODO: We might want to introduce a way to control how unsupported bitwidth 39 /// are handled and explicitly fail if wanted. 40 class SPIRVTypeConverter : public TypeConverter { 41 public: 42 explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr); 43 44 /// Gets the number of bytes used for a type when converted to SPIR-V 45 /// type. Note that it doesnt account for whether the type is legal for a 46 /// SPIR-V target (described by spirv::TargetEnvAttr). Returns None on 47 /// failure. 48 static Optional<int64_t> getConvertedTypeNumBytes(Type); 49 50 /// Gets the SPIR-V correspondence for the standard index type. 51 static Type getIndexType(MLIRContext *context); 52 53 /// Returns the corresponding memory space for memref given a SPIR-V storage 54 /// class. 55 static unsigned getMemorySpaceForStorageClass(spirv::StorageClass); 56 57 /// Returns the SPIR-V storage class given a memory space for memref. Return 58 /// llvm::None if the memory space does not map to any SPIR-V storage class. 59 static Optional<spirv::StorageClass> 60 getStorageClassForMemorySpace(unsigned space); 61 62 private: 63 spirv::TargetEnv targetEnv; 64 }; 65 66 /// Appends to a pattern list additional patterns for translating the builtin 67 /// `func` op to the SPIR-V dialect. These patterns do not handle shader 68 /// interface/ABI; they convert function parameters to be of SPIR-V allowed 69 /// types. 70 void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context, 71 SPIRVTypeConverter &typeConverter, 72 OwningRewritePatternList &patterns); 73 74 namespace spirv { 75 class AccessChainOp; 76 class FuncOp; 77 78 class SPIRVConversionTarget : public ConversionTarget { 79 public: 80 /// Creates a SPIR-V conversion target for the given target environment. 81 static std::unique_ptr<SPIRVConversionTarget> get(TargetEnvAttr targetAttr); 82 83 private: 84 explicit SPIRVConversionTarget(TargetEnvAttr targetAttr); 85 86 // Be explicit that instance of this class cannot be copied or moved: there 87 // are lambdas capturing fields of the instance. 88 SPIRVConversionTarget(const SPIRVConversionTarget &) = delete; 89 SPIRVConversionTarget(SPIRVConversionTarget &&) = delete; 90 SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete; 91 SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete; 92 93 /// Returns true if the given `op` is legal to use under the current target 94 /// environment. 95 bool isLegalOp(Operation *op); 96 97 TargetEnv targetEnv; 98 }; 99 100 /// Returns the value for the given `builtin` variable. This function gets or 101 /// inserts the global variable associated for the builtin within the nearest 102 /// enclosing op that has a symbol table. Returns null Value if such an 103 /// enclosing op cannot be found. 104 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, 105 OpBuilder &builder); 106 107 /// Performs the index computation to get to the element at `indices` of the 108 /// memory pointed to by `basePtr`, using the layout map of `baseType`. 109 110 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap 111 // that has static strides. Extend to handle dynamic strides. 112 spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, 113 MemRefType baseType, Value basePtr, 114 ValueRange indices, Location loc, 115 OpBuilder &builder); 116 117 /// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its 118 /// arguments. 119 LogicalResult setABIAttrs(spirv::FuncOp funcOp, 120 EntryPointABIAttr entryPointInfo, 121 ArrayRef<InterfaceVarABIAttr> argABIInfo); 122 } // namespace spirv 123 } // namespace mlir 124 125 #endif // MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H 126