1 //===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Defines utilities to use while converting to the SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H 14 #define MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H 15 16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/ADT/SmallSet.h" 21 22 namespace mlir { 23 24 //===----------------------------------------------------------------------===// 25 // Type Converter 26 //===----------------------------------------------------------------------===// 27 28 /// Type conversion from builtin types to SPIR-V types for shader interface. 29 /// 30 /// For memref types, this converter additionally performs type wrapping to 31 /// satisfy shader interface requirements: shader interface types must be 32 /// pointers to structs. 33 class SPIRVTypeConverter : public TypeConverter { 34 public: 35 struct Options { 36 /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if 37 /// no native support. 38 /// 39 /// Non-32-bit scalar types require special hardware support that may not 40 /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar 41 /// types require special capabilities or extensions. This option controls 42 /// whether to use 32-bit types to emulate, if a scalar type of a certain 43 /// bitwidth is not supported in the target environment. This requires the 44 /// runtime to also feed in data with a matched bitwidth and layout for 45 /// interface types. The runtime can do that by inspecting the SPIR-V 46 /// module. 47 /// 48 /// If the original scalar type has less than 32-bit, a multiple of its 49 /// values will be packed into one 32-bit value to be memory efficient. 50 bool emulateNon32BitScalarTypes; 51 52 /// The number of bits to store a boolean value. It is eight bits by 53 /// default. 54 unsigned boolNumBits; 55 56 // Note: we need this instead of inline initializers becuase of 57 // https://bugs.llvm.org/show_bug.cgi?id=36684 OptionsOptions58 Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {} 59 }; 60 61 explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, 62 Options options = {}); 63 64 /// Gets the SPIR-V correspondence for the standard index type. 65 static Type getIndexType(MLIRContext *context); 66 67 /// Returns the corresponding memory space for memref given a SPIR-V storage 68 /// class. 69 static unsigned getMemorySpaceForStorageClass(spirv::StorageClass); 70 71 /// Returns the SPIR-V storage class given a memory space for memref. Return 72 /// llvm::None if the memory space does not map to any SPIR-V storage class. 73 static Optional<spirv::StorageClass> 74 getStorageClassForMemorySpace(unsigned space); 75 76 /// Returns the options controlling the SPIR-V type converter. 77 const Options &getOptions() const; 78 79 private: 80 spirv::TargetEnv targetEnv; 81 Options options; 82 }; 83 84 //===----------------------------------------------------------------------===// 85 // Conversion Target 86 //===----------------------------------------------------------------------===// 87 88 // The default SPIR-V conversion target. 89 // 90 // It takes a SPIR-V target environment and controls operation legality based on 91 // the their availability in the target environment. 92 class SPIRVConversionTarget : public ConversionTarget { 93 public: 94 /// Creates a SPIR-V conversion target for the given target environment. 95 static std::unique_ptr<SPIRVConversionTarget> 96 get(spirv::TargetEnvAttr targetAttr); 97 98 private: 99 explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr); 100 101 // Be explicit that instance of this class cannot be copied or moved: there 102 // are lambdas capturing fields of the instance. 103 SPIRVConversionTarget(const SPIRVConversionTarget &) = delete; 104 SPIRVConversionTarget(SPIRVConversionTarget &&) = delete; 105 SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete; 106 SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete; 107 108 /// Returns true if the given `op` is legal to use under the current target 109 /// environment. 110 bool isLegalOp(Operation *op); 111 112 spirv::TargetEnv targetEnv; 113 }; 114 115 //===----------------------------------------------------------------------===// 116 // Patterns and Utility Functions 117 //===----------------------------------------------------------------------===// 118 119 /// Appends to a pattern list additional patterns for translating the builtin 120 /// `func` op to the SPIR-V dialect. These patterns do not handle shader 121 /// interface/ABI; they convert function parameters to be of SPIR-V allowed 122 /// types. 123 void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 124 RewritePatternSet &patterns); 125 126 namespace spirv { 127 class AccessChainOp; 128 129 /// Returns the value for the given `builtin` variable. This function gets or 130 /// inserts the global variable associated for the builtin within the nearest 131 /// symbol table enclosing `op`. Returns null Value on error. 132 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin, 133 OpBuilder &builder); 134 135 /// Gets the value at the given `offset` of the push constant storage with a 136 /// total of `elementCount` 32-bit integers. A global variable will be created 137 /// in the nearest symbol table enclosing `op` for the push constant storage if 138 /// not existing. Load ops will be created via the given `builder` to load 139 /// values from the push constant. Returns null Value on error. 140 Value getPushConstantValue(Operation *op, unsigned elementCount, 141 unsigned offset, OpBuilder &builder); 142 143 /// Generates IR to perform index linearization with the given `indices` and 144 /// their corresponding `strides`, adding an initial `offset`. 145 Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides, 146 int64_t offset, Location loc, OpBuilder &builder); 147 148 /// Performs the index computation to get to the element at `indices` of the 149 /// memory pointed to by `basePtr`, using the layout map of `baseType`. 150 151 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap 152 // that has static strides. Extend to handle dynamic strides. 153 spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter, 154 MemRefType baseType, Value basePtr, 155 ValueRange indices, Location loc, 156 OpBuilder &builder); 157 158 } // namespace spirv 159 } // namespace mlir 160 161 #endif // MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H 162