1 //===- SPIRVConversion.h - SPIR-V Conversion Utilities ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines utilities to use while converting to the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H
14 #define MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H
15 
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/ADT/SmallSet.h"
21 
22 namespace mlir {
23 
24 //===----------------------------------------------------------------------===//
25 // Type Converter
26 //===----------------------------------------------------------------------===//
27 
28 /// Type conversion from builtin types to SPIR-V types for shader interface.
29 ///
30 /// For memref types, this converter additionally performs type wrapping to
31 /// satisfy shader interface requirements: shader interface types must be
32 /// pointers to structs.
33 class SPIRVTypeConverter : public TypeConverter {
34 public:
35   struct Options {
36     /// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
37     /// no native support.
38     ///
39     /// Non-32-bit scalar types require special hardware support that may not
40     /// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
41     /// types require special capabilities or extensions. This option controls
42     /// whether to use 32-bit types to emulate, if a scalar type of a certain
43     /// bitwidth is not supported in the target environment. This requires the
44     /// runtime to also feed in data with a matched bitwidth and layout for
45     /// interface types. The runtime can do that by inspecting the SPIR-V
46     /// module.
47     ///
48     /// If the original scalar type has less than 32-bit, a multiple of its
49     /// values will be packed into one 32-bit value to be memory efficient.
50     bool emulateNon32BitScalarTypes;
51 
52     /// The number of bits to store a boolean value. It is eight bits by
53     /// default.
54     unsigned boolNumBits;
55 
56     // Note: we need this instead of inline initializers becuase of
57     // https://bugs.llvm.org/show_bug.cgi?id=36684
OptionsOptions58     Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {}
59   };
60 
61   explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
62                               Options options = {});
63 
64   /// Gets the SPIR-V correspondence for the standard index type.
65   static Type getIndexType(MLIRContext *context);
66 
67   /// Returns the corresponding memory space for memref given a SPIR-V storage
68   /// class.
69   static unsigned getMemorySpaceForStorageClass(spirv::StorageClass);
70 
71   /// Returns the SPIR-V storage class given a memory space for memref. Return
72   /// llvm::None if the memory space does not map to any SPIR-V storage class.
73   static Optional<spirv::StorageClass>
74   getStorageClassForMemorySpace(unsigned space);
75 
76   /// Returns the options controlling the SPIR-V type converter.
77   const Options &getOptions() const;
78 
79 private:
80   spirv::TargetEnv targetEnv;
81   Options options;
82 };
83 
84 //===----------------------------------------------------------------------===//
85 // Conversion Target
86 //===----------------------------------------------------------------------===//
87 
88 // The default SPIR-V conversion target.
89 //
90 // It takes a SPIR-V target environment and controls operation legality based on
91 // the their availability in the target environment.
92 class SPIRVConversionTarget : public ConversionTarget {
93 public:
94   /// Creates a SPIR-V conversion target for the given target environment.
95   static std::unique_ptr<SPIRVConversionTarget>
96   get(spirv::TargetEnvAttr targetAttr);
97 
98 private:
99   explicit SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr);
100 
101   // Be explicit that instance of this class cannot be copied or moved: there
102   // are lambdas capturing fields of the instance.
103   SPIRVConversionTarget(const SPIRVConversionTarget &) = delete;
104   SPIRVConversionTarget(SPIRVConversionTarget &&) = delete;
105   SPIRVConversionTarget &operator=(const SPIRVConversionTarget &) = delete;
106   SPIRVConversionTarget &operator=(SPIRVConversionTarget &&) = delete;
107 
108   /// Returns true if the given `op` is legal to use under the current target
109   /// environment.
110   bool isLegalOp(Operation *op);
111 
112   spirv::TargetEnv targetEnv;
113 };
114 
115 //===----------------------------------------------------------------------===//
116 // Patterns and Utility Functions
117 //===----------------------------------------------------------------------===//
118 
119 /// Appends to a pattern list additional patterns for translating the builtin
120 /// `func` op to the SPIR-V dialect. These patterns do not handle shader
121 /// interface/ABI; they convert function parameters to be of SPIR-V allowed
122 /// types.
123 void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
124                                         RewritePatternSet &patterns);
125 
126 namespace spirv {
127 class AccessChainOp;
128 
129 /// Returns the value for the given `builtin` variable. This function gets or
130 /// inserts the global variable associated for the builtin within the nearest
131 /// symbol table enclosing `op`. Returns null Value on error.
132 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
133                               OpBuilder &builder);
134 
135 /// Gets the value at the given `offset` of the push constant storage with a
136 /// total of `elementCount` 32-bit integers. A global variable will be created
137 /// in the nearest symbol table enclosing `op` for the push constant storage if
138 /// not existing. Load ops will be created via the given `builder` to load
139 /// values from the push constant. Returns null Value on error.
140 Value getPushConstantValue(Operation *op, unsigned elementCount,
141                            unsigned offset, OpBuilder &builder);
142 
143 /// Generates IR to perform index linearization with the given `indices` and
144 /// their corresponding `strides`, adding an initial `offset`.
145 Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
146                      int64_t offset, Location loc, OpBuilder &builder);
147 
148 /// Performs the index computation to get to the element at `indices` of the
149 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
150 
151 // TODO: This method assumes that the `baseType` is a MemRefType with AffineMap
152 // that has static strides. Extend to handle dynamic strides.
153 spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
154                                    MemRefType baseType, Value basePtr,
155                                    ValueRange indices, Location loc,
156                                    OpBuilder &builder);
157 
158 } // namespace spirv
159 } // namespace mlir
160 
161 #endif // MLIR_DIALECT_SPIRV_SPIRVCONVERSION_H
162